Skip to content

Commit fecf82c

Browse files
committed
Address review comments
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent 720425f commit fecf82c

File tree

4 files changed

+16
-15
lines changed

4 files changed

+16
-15
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#define GGML_CUDA_CC_ADA_LOVELACE 890
5050
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
5151
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
52+
#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
5253

5354
// AMD
5455
// GCN/CNDA, wave size is 64
@@ -79,7 +80,7 @@
7980
#define GGML_CUDA_CC_QY2 (GGML_MUSA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
8081
#define GGML_CUDA_CC_NG (GGML_MUSA_CC_OFFSET_MTHREADS + 0x310) // TBD
8182

82-
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS)
83+
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
8384
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
8485
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NEXT)
8586
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
@@ -229,33 +230,33 @@ static bool fp16_available(const int cc) {
229230
}
230231

231232
static bool fast_fp16_available(const int cc) {
232-
return (!GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
233+
return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
233234
}
234235

235236
// To be used for feature selection of external libraries, e.g. cuBLAS.
236237
static bool fast_fp16_hardware_available(const int cc) {
237-
return (!GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
238+
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
238239
}
239240

240241
// Any FP16 tensor core instructions are available for ggml code.
241242
static bool fp16_mma_available(const int cc) {
242243
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
243244
return false;
244245
#else
245-
return !GGML_CUDA_CC_IS_MTHREADS(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
246+
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
246247
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
247248
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
248249
}
249250

250251
// To be used for feature selection of external libraries, e.g. cuBLAS.
251252
static bool fp16_mma_hardware_available(const int cc) {
252-
return !GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_VOLTA ||
253+
return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA ||
253254
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
254255
}
255256

256257
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
257258
static bool new_mma_available(const int cc) {
258-
return !GGML_CUDA_CC_IS_MTHREADS(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
259+
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
259260
}
260261

261262
static bool cp_async_available(const int cc) {
@@ -433,13 +434,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
433434

434435
#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
435436

436-
#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
437+
#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
437438
return __dp4a(a, b, c);
438-
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
439+
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
439440
const int8_t * a8 = (const int8_t *) &a;
440441
const int8_t * b8 = (const int8_t *) &b;
441442
return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
442-
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
443+
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
443444

444445
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
445446
}

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1192,7 +1192,7 @@ static void ggml_cuda_op_mul_mat_cublas(
11921192

11931193
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
11941194

1195-
if (((cc >= GGML_CUDA_CC_VOLTA && !GGML_CUDA_CC_IS_MTHREADS(cc)) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
1195+
if (((cc >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc)) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
11961196
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
11971197
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
11981198
if (src0->type != GGML_TYPE_F16) {

ggml/src/ggml-cuda/mmq.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void ggml_cuda_op_mul_mat_q(
2828
// Also its fixup needs to allocate a temporary buffer in the memory pool.
2929
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
3030
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
31-
!GGML_CUDA_CC_IS_MTHREADS(cc) && src1_ncols == ne11;
31+
GGML_CUDA_CC_IS_NVIDIA(cc) && src1_ncols == ne11;
3232
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
3333

3434
switch (src0->type) {
@@ -145,7 +145,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
145145
return true;
146146
#endif //GGML_CUDA_FORCE_MMQ
147147

148-
if (!GGML_CUDA_CC_IS_MTHREADS(cc)) {
148+
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
149149
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
150150
}
151151

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ struct tile_x_sizes {
9090

9191
static int get_mmq_x_max_host(const int cc) {
9292
return new_mma_available(cc) ? 128 :
93-
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && !GGML_CUDA_CC_IS_MTHREADS(cc) ?
93+
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc) ?
9494
#ifdef GGML_CUDA_FORCE_MMQ
9595
128 : 64;
9696
#else
@@ -124,7 +124,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
124124

125125
static int get_mmq_y_host(const int cc) {
126126
return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
127-
((ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && !GGML_CUDA_CC_IS_MTHREADS(cc)) ? 128 : 64);
127+
((ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc)) ? 128 : 64);
128128
}
129129

130130
static constexpr __device__ int get_mmq_y_device() {
@@ -2832,7 +2832,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
28322832
const int mmq_x_max = get_mmq_x_max_host(cc);
28332833
const int mmq_y = get_mmq_y_host(cc);
28342834
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
2835-
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && !GGML_CUDA_CC_IS_MTHREADS(cc);
2835+
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc);
28362836

28372837
int mmq_x_best = 0;
28382838
int nparts_best = INT_MAX;

0 commit comments

Comments
 (0)