Skip to content

Commit e3b7c57

Browse files
__shfl_sync workaround for movmatrix
1 parent 60958f6 commit e3b7c57

File tree

1 file changed

+48
-29
lines changed

1 file changed

+48
-29
lines changed

ggml/src/ggml-cuda/mma.cuh

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,44 @@
1515

1616
#include "common.cuh"
1717

18+
19+
#if CUDART_VERSION >= 12000
20+
21+
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
22+
int ret = 0;
23+
24+
#ifdef NEW_MMA_AVAILABLE
25+
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
26+
: "+r"(ret) : "r"(x));
27+
#else
28+
NO_DEVICE_CODE;
29+
#endif // defined(NEW_MMA_AVAILABLE)
30+
return ret;
31+
}
32+
33+
#else
34+
35+
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
36+
// Imagine transposing row-major matrix to column-major matrix.
37+
const int src_i_low = 2 * (threadIdx.x % 4);
38+
const int src_i_high = src_i_low + 1;
39+
const int src_j = threadIdx.x / 4;
40+
41+
const int src_laneid_low = src_i_low * 4 + src_j / 2;
42+
const int src_laneid_high = src_i_high * 4 + src_j / 2;
43+
44+
const int shift_low = ((src_j + 0) % 2) * 16;
45+
const int shift_high = ((src_j + 1) % 2) * 16;
46+
47+
const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
48+
const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
49+
50+
return ret_low | ret_high;
51+
}
52+
53+
#endif // CUDART_VERSION >= 12000
54+
55+
1856
template <typename T>
1957
struct mma_A_I16K4 {
2058
static_assert(sizeof(T) == 4, "bad type size");
@@ -119,21 +157,14 @@ struct mma_A_I16K8 {
119157
}
120158

121159
__device__ __forceinline__ void transpose() {
122-
#ifdef NEW_MMA_AVAILABLE
123160
int * xi = (int *) x;
124-
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
125-
: "+r"(xi[0]) : "r"(xi[0]));
126-
int tmp = 0;
127-
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
128-
: "+r"(tmp) : "r"(xi[1]));
129-
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
130-
: "+r"(xi[1]) : "r"(xi[2]));
161+
xi[0] = ggml_cuda_movmatrix(xi[0]);
162+
163+
const int tmp = ggml_cuda_movmatrix(xi[1]);
164+
xi[1] = ggml_cuda_movmatrix(xi[2]);
131165
xi[2] = tmp;
132-
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
133-
: "+r"(xi[3]) : "r"(xi[3]));
134-
#else
135-
NO_DEVICE_CODE;
136-
#endif // NEW_MMA_AVAILABLE
166+
167+
xi[3] = ggml_cuda_movmatrix(xi[3]);
137168
}
138169
};
139170

@@ -350,16 +381,10 @@ struct mma_C_I16J8<half2> {
350381
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
351382
mma_B_J8K8<half2> mma_B;
352383

353-
#ifdef NEW_MMA_AVAILABLE
354384
int * xi = (int *) x;
355385
int * Bxi = (int *) mma_B.x;
356-
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
357-
: "+r"(Bxi[0]) : "r"(xi[0]));
358-
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
359-
: "+r"(Bxi[1]) : "r"(xi[1]));
360-
#else
361-
NO_DEVICE_CODE;
362-
#endif // NEW_MMA_AVAILABLE
386+
Bxi[0] = ggml_cuda_movmatrix(xi[0]);
387+
Bxi[1] = ggml_cuda_movmatrix(xi[1]);
363388

364389
return mma_B;
365390
}
@@ -417,15 +442,9 @@ struct mma_C_I16J8<float> {
417442
mma_B.x[0] = make_half2(x[0], x[1]);
418443
mma_B.x[1] = make_half2(x[2], x[3]);
419444

420-
#ifdef NEW_MMA_AVAILABLE
421445
int * Bxi = (int *) mma_B.x;
422-
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %0;"
423-
: "+r"(Bxi[0]) : );
424-
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %0;"
425-
: "+r"(Bxi[1]) : );
426-
#else
427-
NO_DEVICE_CODE;
428-
#endif // NEW_MMA_AVAILABLE
446+
Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
447+
Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
429448

430449
return mma_B;
431450
}

0 commit comments

Comments
 (0)