Commit 09a875e

Address review comments
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent 876ee4c commit 09a875e

2 files changed (+9, -8 lines)

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 6 deletions

@@ -201,18 +201,14 @@ typedef float2 dfloat2;
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610

-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
 #define FP16_MMA_AVAILABLE
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)

 #if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))

-#if defined(GGML_USE_MUSA)
-#define FP16_MMA_AVAILABLE
-#endif // defined(GGML_USE_MUSA)
-
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
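Net effect of this hunk: MUSA builds now get FP16_MMA_AVAILABLE from the same #if as CUDA, instead of the standalone GGML_USE_MUSA block deleted below it. As a sanity check, here is a minimal host-side sketch of the merged condition (not part of the commit); the GGML_CUDA_CC_VOLTA value is a stand-in, and emulating a MUSA pass by defining GGML_USE_MUSA while leaving __CUDA_ARCH__ and GGML_USE_HIP undefined is an assumption for illustration:

// Minimal host-side sketch of the merged #if from common.cuh.
// Assumptions for illustration: GGML_CUDA_CC_VOLTA's value is a stand-in,
// and we emulate a MUSA device pass by defining GGML_USE_MUSA while
// leaving __CUDA_ARCH__ and GGML_USE_HIP undefined.
#include <cstdio>

#define GGML_CUDA_CC_VOLTA 700 // stand-in value
#define GGML_USE_MUSA           // pretend this is a MUSA compile

#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
#define FP16_MMA_AVAILABLE
#endif

int main() {
#ifdef FP16_MMA_AVAILABLE
    std::puts("FP16_MMA_AVAILABLE: defined (the MUSA arm of the || fired)");
#else
    std::puts("FP16_MMA_AVAILABLE: not defined");
#endif
    return 0;
}

The first clause alone evaluates false here (an identifier left undefined inside #if arithmetic becomes 0, so __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA is 0 >= 700), so the macro is enabled purely by the new || defined(GGML_USE_MUSA) arm, which is exactly what the removed standalone block used to do.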

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 7 additions & 2 deletions

@@ -1227,9 +1227,14 @@ static void ggml_cuda_op_mul_mat_cublas(

     const int cc = ggml_cuda_info().devices[id].cc;

+    const bool support_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
+                              (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+
+    const bool support_fp16 = (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
+                              GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;

-    if ((GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2)) && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if (support_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
         if (src1->type != GGML_TYPE_BF16) {
             const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1257,7 +1262,7 @@ static void ggml_cuda_op_mul_mat_cublas(

         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
         to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2)) && use_fp16) {
+    } else if (support_fp16 && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
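The change here is a readability refactor: the two compound capability predicates are hoisted out of the if / else if chain into named const bools, with the branch logic otherwise unchanged. Below is a minimal sketch of the same pattern; device_info, is_nvidia/is_amd/is_mthreads, and the CC_* values are hypothetical stand-ins for the real GGML_CUDA_CC_IS_NVIDIA / GGML_CUDA_CC_IS_AMD / GGML_CUDA_CC_IS_MTHREADS macros and the GGML_CUDA_CC_VOLTA / GGML_CUDA_CC_QY2 constants, and the tensor type and contiguity checks of the real code are omitted:

// Sketch of the refactor pattern: evaluate each predicate once, name it,
// then branch on the names. Everything below is a hypothetical stand-in
// for the real macros/constants; the real branch conditions also check
// tensor types and contiguity, omitted here.
#include <cstdio>

struct device_info {
    bool is_nvidia, is_amd, is_mthreads;
    int  cc; // compute capability
};

constexpr int CC_VOLTA = 700; // stand-in for GGML_CUDA_CC_VOLTA
constexpr int CC_QY2   = 220; // stand-in for GGML_CUDA_CC_QY2

int main() {
    const device_info dev{/*is_nvidia=*/true, /*is_amd=*/false,
                          /*is_mthreads=*/false, /*cc=*/610};

    // Same shapes as support_bf16/support_fp16 in the commit.
    const bool support_bf16 = dev.is_nvidia || dev.is_amd ||
                              (dev.is_mthreads && dev.cc >= CC_QY2);
    const bool support_fp16 = (dev.is_nvidia && dev.cc >= CC_VOLTA) ||
                              dev.is_amd ||
                              (dev.is_mthreads && dev.cc >= CC_QY2);

    // The branches now read as intent rather than repeated macro chains.
    if (support_bf16) {
        std::puts("BF16 cuBLAS path is eligible (given a BF16 src0)");
    } else if (support_fp16) {
        std::puts("FP16 cuBLAS path is eligible");
    } else {
        std::puts("fall back to FP32 GEMM");
    }
    return 0;
}

Note the asymmetry the refactor preserves: support_bf16 has no Volta gate for NVIDIA (so the cc 610 device above still qualifies for the BF16 path), while support_fp16 requires Volta or newer; Moore Threads devices need QY2 or newer for both.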
