Skip to content

Commit 5ab1491

Browse files
committed
Feat: Enable stream-k for CDNA3
1 parent ba17f62 commit 5ab1491

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

ggml/src/ggml-cuda/mmq.cu

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,8 @@ void ggml_cuda_mul_mat_q(
109109
const int64_t s03 = src0->nb[3] / ts_src0;
110110
const int64_t s3 = dst->nb[3] / ts_dst;
111111

112-
const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
112+
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
113+
|| (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)));
113114

114115
if (!ids) {
115116
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -250,8 +251,9 @@ void ggml_cuda_op_mul_mat_q(
250251
// The stream-k decomposition is only enabled for recent NVIDIA GPUs and AMD CDNA3 GPUs.
251252
// Also its fixup needs to allocate a temporary buffer in the memory pool.
252253
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
253-
const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
254-
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
254+
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
255+
|| (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)))
256+
&& src1_ncols == ne11;
255257
const mmq_args args = {
256258
src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
257259
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3182,7 +3182,7 @@ static __global__ void mul_mat_q(
31823182
__syncthreads();
31833183

31843184
// On AMD without MMA support or on old CUDA the performance with stream-k was worse, use conventional tiling instead:
3185-
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3185+
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(AMD_MMA_AVAILABLE)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
31863186
{
31873187
const int wt = blockIdx.z / nchannels_y;
31883188
const int zt = blockIdx.z - wt*nchannels_y;
@@ -3236,7 +3236,7 @@ static __global__ void mul_mat_q(
32363236
tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
32373237
return;
32383238
}
3239-
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3239+
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(AMD_MMA_AVAILABLE)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
32403240

32413241
const int64_t blocks_per_ne00 = ncols_x / qk;
32423242
constexpr int blocks_per_iter = MMQ_ITER_K / qk;

0 commit comments

Comments
 (0)