15 | 15 |
16 | 16 | #include "common.cuh" |
17 | 17 |
| 18 | + |
| 19 | +#if CUDART_VERSION >= 12000 |
| 20 | + |
| 21 | +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { |
| 22 | + int ret = 0; |
| 23 | + |
| 24 | +#ifdef NEW_MMA_AVAILABLE |
| 25 | + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" |
| 26 | + : "+r"(ret) : "r"(x)); |
| 27 | +#else |
| 28 | + NO_DEVICE_CODE; |
| 29 | +#endif // NEW_MMA_AVAILABLE
| 30 | + return ret; |
| 31 | +} |
| 32 | + |
| 33 | +#else |
| 34 | + |
| 35 | +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { |
| 36 | + // Imagine transposing a row-major matrix to a column-major matrix: output element (i, j) comes from input element (j, i).
| 37 | + const int src_i_low = 2 * (threadIdx.x % 4); |
| 38 | + const int src_i_high = src_i_low + 1; |
| 39 | + const int src_j = threadIdx.x / 4; |
| 40 | + |
| 41 | + const int src_laneid_low = src_i_low * 4 + src_j / 2; |
| 42 | + const int src_laneid_high = src_i_high * 4 + src_j / 2; |
| 43 | + |
| 44 | + const int shift_low = ((src_j + 0) % 2) * 16; |
| 45 | + const int shift_high = ((src_j + 1) % 2) * 16; |
| 46 | + |
| 47 | + const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF; |
| 48 | + const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; |
| 49 | + |
| 50 | + return ret_low | ret_high; |
| 51 | +} |
| 52 | + |
| 53 | +#endif // CUDART_VERSION >= 12000 |
| 54 | + |
| 55 | + |
18 | 56 | template <typename T> |
19 | 57 | struct mma_A_I16K4 { |
20 | 58 | static_assert(sizeof(T) == 4, "bad type size"); |
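A note on the new helper: `movmatrix.sync.aligned.m8n8.trans.b16` transposes an 8x8 matrix of 16-bit elements held collectively by the 32 lanes of a warp, where lane `l` holds the two adjacent elements of row `l/4` at columns `2*(l%4)` (low 16 bits of its register) and `2*(l%4)+1` (high 16 bits). The shuffle fallback inverts that mapping: for each output element it computes which source lane holds the transposed element `(j, i)` and whether it sits in that lane's low or high halfword. As a hypothetical host-side mirror of the same index math (not part of the commit, written here only to illustrate the assumed fragment layout):

```cuda
// Hypothetical host-side reference for ggml_cuda_movmatrix (illustration only,
// not part of the commit). src[l] stands for the 32-bit register of lane l
// under the layout above: row l/4, columns 2*(l%4) (low half) and 2*(l%4)+1
// (high half).
#include <cstdint>

static void movmatrix_reference(const uint32_t src[32], uint32_t dst[32]) {
    for (int lane = 0; lane < 32; ++lane) {
        const int i = lane / 4;     // destination row held by this lane
        const int j = 2*(lane % 4); // destination columns j and j+1

        // Transposed element (i, j) is source element (j, i); source row j
        // is spread over lanes 4*j..4*j+3, two columns per lane.
        const uint32_t reg_lo = src[4*j       + i/2];
        const uint32_t reg_hi = src[4*(j + 1) + i/2];
        const uint16_t lo = (uint16_t) (reg_lo >> (16*(i % 2)));
        const uint16_t hi = (uint16_t) (reg_hi >> (16*(i % 2)));

        dst[lane] = (uint32_t) lo | ((uint32_t) hi << 16);
    }
}
```

Per lane this is exactly what the two `__shfl_sync` lines in the fallback compute: `src_laneid_low`/`src_laneid_high` select the register, `shift_low`/`shift_high` select and reposition the halfword.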
@@ -119,21 +157,14 @@ struct mma_A_I16K8 { |
119 | 157 | } |
120 | 158 |
121 | 159 | __device__ __forceinline__ void transpose() { |
122 | | -#ifdef NEW_MMA_AVAILABLE |
123 | 160 | int * xi = (int *) x; |
124 | | - asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" |
125 | | - : "+r"(xi[0]) : "r"(xi[0])); |
126 | | - int tmp = 0; |
127 | | - asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" |
128 | | - : "+r"(tmp) : "r"(xi[1])); |
129 | | - asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" |
130 | | - : "+r"(xi[1]) : "r"(xi[2])); |
| 161 | + xi[0] = ggml_cuda_movmatrix(xi[0]); |
| 162 | + |
| 163 | + const int tmp = ggml_cuda_movmatrix(xi[1]); |
| 164 | + xi[1] = ggml_cuda_movmatrix(xi[2]); |
131 | 165 | xi[2] = tmp; |
132 | | - asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" |
133 | | - : "+r"(xi[3]) : "r"(xi[3])); |
134 | | -#else |
135 | | - NO_DEVICE_CODE; |
136 | | -#endif // NEW_MMA_AVAILABLE |
| 166 | + |
| 167 | + xi[3] = ggml_cuda_movmatrix(xi[3]); |
137 | 168 | } |
138 | 169 | }; |
139 | 170 |
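In `mma_A_I16K8<half2>` the four 32-bit registers per lane cover a 16x16 tile of b16 values, presumably in `ldmatrix.x4` order: `xi[0]` rows 0-7 x cols 0-7, `xi[1]` rows 8-15 x cols 0-7, `xi[2]` rows 0-7 x cols 8-15, `xi[3]` rows 8-15 x cols 8-15. Transposing the tile then means transposing each 8x8 block and swapping the two off-diagonal blocks, which is what the `tmp` shuffle above does through `ggml_cuda_movmatrix`. A hypothetical host-side equivalent, assuming that block order (not part of the commit):

```cuda
// Blocked 16x16 transpose mirroring transpose() above (illustration only).
#include <cstdint>
#include <utility>

static void transpose_16x16_blocked(uint16_t m[16][16]) {
    // Transpose each 8x8 block in place (the four movmatrix calls) ...
    for (int bi = 0; bi < 2; ++bi) {
        for (int bj = 0; bj < 2; ++bj) {
            for (int i = 0; i < 8; ++i) {
                for (int j = i + 1; j < 8; ++j) {
                    std::swap(m[8*bi + i][8*bj + j], m[8*bi + j][8*bj + i]);
                }
            }
        }
    }
    // ... then swap the off-diagonal blocks (the tmp dance with xi[1]/xi[2]).
    for (int i = 0; i < 8; ++i) {
        for (int j = 0; j < 8; ++j) {
            std::swap(m[8 + i][j], m[i][8 + j]);
        }
    }
}
```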
@@ -350,16 +381,10 @@ struct mma_C_I16J8<half2> { |
350 | 381 | __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() { |
351 | 382 | mma_B_J8K8<half2> mma_B; |
352 | 383 |
353 | | -#ifdef NEW_MMA_AVAILABLE |
354 | 384 | int * xi = (int *) x; |
355 | 385 | int * Bxi = (int *) mma_B.x; |
356 | | - asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" |
357 | | - : "+r"(Bxi[0]) : "r"(xi[0])); |
358 | | - asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" |
359 | | - : "+r"(Bxi[1]) : "r"(xi[1])); |
360 | | -#else |
361 | | - NO_DEVICE_CODE; |
362 | | -#endif // NEW_MMA_AVAILABLE |
| 386 | + Bxi[0] = ggml_cuda_movmatrix(xi[0]); |
| 387 | + Bxi[1] = ggml_cuda_movmatrix(xi[1]); |
363 | 388 |
364 | 389 | return mma_B; |
365 | 390 | } |
@@ -417,15 +442,9 @@ struct mma_C_I16J8<float> { |
417 | 442 | mma_B.x[0] = make_half2(x[0], x[1]); |
418 | 443 | mma_B.x[1] = make_half2(x[2], x[3]); |
419 | 444 |
420 | | -#ifdef NEW_MMA_AVAILABLE |
421 | 445 | int * Bxi = (int *) mma_B.x; |
422 | | - asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %0;" |
423 | | - : "+r"(Bxi[0]) : ); |
424 | | - asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %0;" |
425 | | - : "+r"(Bxi[1]) : ); |
426 | | -#else |
427 | | - NO_DEVICE_CODE; |
428 | | -#endif // NEW_MMA_AVAILABLE |
| 446 | + Bxi[0] = ggml_cuda_movmatrix(Bxi[0]); |
| 447 | + Bxi[1] = ggml_cuda_movmatrix(Bxi[1]); |
429 | 448 |
|
430 | 449 | return mma_B; |
431 | 450 | } |
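All call sites (`transpose()` and both `to_mma_B()` variants) now route through the same helper, so the CUDA < 12.0 fallback must agree bit-for-bit with the PTX instruction. It is easy to sanity-check in isolation; a minimal hypothetical test harness follows (not part of the commit; the emulated helper is copied verbatim from the diff above, everything else is invented for illustration):

```cuda
// Standalone sanity check for the shuffle-based emulation (hypothetical test).
// Assumes the m8n8.b16 fragment layout described earlier: lane l holds row
// l/4, columns 2*(l%4) (low half) and 2*(l%4)+1 (high half).
#include <cstdio>

#define WARP_SIZE 32

static __device__ __forceinline__ int movmatrix_emulated(const int x) {
    const int src_i_low  = 2 * (threadIdx.x % 4);
    const int src_i_high = src_i_low + 1;
    const int src_j      = threadIdx.x / 4;

    const int src_laneid_low  = src_i_low  * 4 + src_j / 2;
    const int src_laneid_high = src_i_high * 4 + src_j / 2;

    const int shift_low  = ((src_j + 0) % 2) * 16;
    const int shift_high = ((src_j + 1) % 2) * 16;

    const int ret_low  = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF;
    const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;

    return ret_low | ret_high;
}

__global__ void test_movmatrix(int * nfail) {
    // Each lane packs its two halfwords as their row-major indices i*8 + j.
    const int i = threadIdx.x / 4;
    const int j = 2 * (threadIdx.x % 4);
    const int x = (i*8 + j) | ((i*8 + j + 1) << 16);

    const int y = movmatrix_emulated(x);

    // After the transpose this lane should hold elements (j, i) and (j+1, i)
    // of the original matrix:
    const int expected = (j*8 + i) | (((j + 1)*8 + i) << 16);
    if (y != expected) {
        atomicAdd(nfail, 1);
    }
}

int main() {
    int * nfail;
    cudaMallocManaged(&nfail, sizeof(int));
    *nfail = 0;
    test_movmatrix<<<1, WARP_SIZE>>>(nfail);
    cudaDeviceSynchronize();
    printf("%d mismatching lanes\n", *nfail);
    cudaFree(nfail);
    return 0;
}
```

Because each lane encodes its two halfwords as their row-major element indices, any mismatch identifies both the lane and the halfword that was routed incorrectly.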