Commit d918041

Address review comments

Signed-off-by: Xiaodong Ye <[email protected]>

1 parent bc30598, commit d918041

2 files changed, +9 -8 lines


ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 6 deletions
```diff
@@ -201,18 +201,14 @@ typedef float2 dfloat2;
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
 #define FP16_MMA_AVAILABLE
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
 
 #if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
 
-#if defined(GGML_USE_MUSA)
-#define FP16_MMA_AVAILABLE
-#endif // defined(GGML_USE_MUSA)
-
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
```
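The net effect of this hunk is that MUSA builds now pick up FP16_MMA_AVAILABLE through the merged gate rather than through the standalone GGML_USE_MUSA block that previously sat below the ROCWMMA one. Since the old gate was two independent #if blocks and the new one is a single condition, a quick truth-table check confirms they define the macro in the same cases, treating the HIP test as a single input (the commit also shortens !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) to !defined(GGML_USE_HIP), which is equivalent as long as HIP builds always target the AMD platform). The sketch below is illustrative only, not code from the tree, and the function names are hypothetical:

```cpp
// Illustrative sanity check, not part of the commit: model the preprocessor
// inputs as booleans and verify that the merged FP16_MMA_AVAILABLE gate
// matches the old two-block form for every input combination.
//   hip   = defined(GGML_USE_HIP) (and, in the old form, __HIP_PLATFORM_AMD__)
//   volta = __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
//   musa  = defined(GGML_USE_MUSA)
constexpr bool old_fp16_mma(bool hip, bool volta, bool musa) {
    bool available = false;
    if (!hip && volta) { available = true; } // old main #if block
    if (musa)          { available = true; } // old standalone GGML_USE_MUSA block
    return available;
}

constexpr bool new_fp16_mma(bool hip, bool volta, bool musa) {
    return (!hip && volta) || musa;          // merged condition from this commit
}

constexpr bool gates_agree() {
    for (int m = 0; m < 8; ++m) { // enumerate all 2^3 input combinations
        const bool hip   = (m & 1) != 0;
        const bool volta = (m & 2) != 0;
        const bool musa  = (m & 4) != 0;
        if (old_fp16_mma(hip, volta, musa) != new_fp16_mma(hip, volta, musa)) {
            return false;
        }
    }
    return true;
}

static_assert(gates_agree(), "merged gate must match the old pair of #if blocks");
```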

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 7 additions & 2 deletions
```diff
@@ -1202,9 +1202,14 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const int cc = ggml_cuda_info().devices[id].cc;
 
+    const bool support_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
+                              (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+
+    const bool support_fp16 = (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
+                              GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if ((GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2)) && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if (support_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
         if (src1->type != GGML_TYPE_BF16) {
             const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1232,7 +1237,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
         to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2)) && use_fp16) {
+    } else if (support_fp16 && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
```
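On the ggml-cuda.cu side the change is a mechanical refactor: the long inline vendor checks are hoisted into named support_bf16 and support_fp16 booleans that the BF16 and FP16 cuBLAS branches then test, with MThreads devices at QY2 or newer qualifying for both paths. The dispatch shape this produces is roughly the following sketch; the struct, enum, and numeric thresholds are hypothetical stand-ins for illustration, not ggml API:

```cpp
#include <cstdio>

// Hypothetical model of the dispatch pattern after the refactor: capability
// predicates are computed once, given names, and each branch reads as
// "capability && operand conditions". Thresholds are placeholders, not the
// real GGML_CUDA_CC_* values.
struct device_caps {
    bool is_nvidia;
    bool is_amd;
    bool is_mthreads;
    int  cc; // compute-capability code
};

enum class gemm_path { bf16, fp16, fp32 };

constexpr int CC_VOLTA = 700; // placeholder for GGML_CUDA_CC_VOLTA
constexpr int CC_QY2   = 220; // placeholder for GGML_CUDA_CC_QY2

static gemm_path select_gemm_path(const device_caps & d, bool src_is_bf16, bool want_fp16) {
    const bool support_bf16 = d.is_nvidia || d.is_amd ||
                              (d.is_mthreads && d.cc >= CC_QY2);
    const bool support_fp16 = (d.is_nvidia && d.cc >= CC_VOLTA) || d.is_amd ||
                              (d.is_mthreads && d.cc >= CC_QY2);

    if (support_bf16 && src_is_bf16) {
        return gemm_path::bf16; // multiply in bf16, convert the result to fp32
    }
    if (support_fp16 && want_fp16) {
        return gemm_path::fp16; // convert operands to fp16, multiply, convert back
    }
    return gemm_path::fp32;     // default: plain fp32 GEMM
}

int main() {
    const device_caps qy2{false, false, true, CC_QY2};
    std::printf("QY2 takes fp16 path: %d\n", select_gemm_path(qy2, false, true) == gemm_path::fp16);
}
```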
