Skip to content

Commit 3e18dba

Browse files
HIP: Patch failed testcase in WMMA-MMQ kernels for RDNA 4 (ggml-org#17502)
* patch failed test case MUL_MAT(type_a=q4_0,type_b=f32,m=576,n=512,k=576,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) for enabling WMMA on RDNA4 * Quick clean up on mma.cuh to add ggml_cuda_memcpy_1 back in for half2 and bfloat162
1 parent eeb5605 commit 3e18dba

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

ggml/src/ggml-cuda/mma.cuh

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -437,18 +437,27 @@ namespace ggml_cuda_mma {
437437
xi[0] = xs[0];
438438
}
439439
#elif defined(AMD_WMMA_AVAILABLE)
440-
if constexpr (I == 16 && J == 4) {
441-
int64_t * xi = (int64_t *) t.x;
442-
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
443-
xi[0] = xs[0];
444-
}else if constexpr (I == 16 && J == 8) {
445-
int64_t * xi = (int64_t *) t.x;
446-
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
447-
xi[0] = xs[0];
440+
if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
441+
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
442+
443+
} else if constexpr (std::is_same_v<T, int>) {
444+
if constexpr (I == 16 && J == 4) {
445+
int64_t * xi = (int64_t *) t.x;
446+
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
447+
xi[0] = xs[0];
448448

449-
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
450-
xi[1] = xs1[0];
451-
}else{
449+
}else if constexpr (I == 16 && J == 8) {
450+
int64_t * xi = (int64_t *) t.x;
451+
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
452+
xi[0] = xs[0];
453+
454+
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
455+
xi[1] = xs1[0];
456+
457+
}else{
458+
NO_DEVICE_CODE;
459+
}
460+
} else {
452461
NO_DEVICE_CODE;
453462
}
454463
#else

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3701,7 +3701,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
37013701
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
37023702
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
37033703
const size_t nbs_ids = mmq_x*sizeof(int);
3704-
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
3704+
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
37053705
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
37063706
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
37073707
}

0 commit comments

Comments
 (0)