Commit 9d27c38

Use namespace alias wmma instead of lots of ifdefs.
1 parent 828577a commit 9d27c38
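
The change is mechanical, but the pattern is worth spelling out: select the backend once at the top of the file, bind it to a namespace alias, and write every fragment type and intrinsic call against the alias. Below is a minimal stand-alone sketch of that pattern, not code from this commit: the USE_ROCWMMA guard and the tile_mma kernel are hypothetical stand-ins, but the alias works the same way as in the diff.

// Minimal sketch of the namespace-alias pattern (hypothetical guard and
// kernel name; the real guards are the GGML_* macros in the diff below).
#if defined(USE_ROCWMMA)               // stand-in for the HIP/rocWMMA guards
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;              // AMD backend
#else
#include <cuda_fp16.h>                 // half
#include <mma.h>
namespace wmma = nvcuda::wmma;         // NVIDIA backend
#endif

// One warp computes C = A*B for a single 16x16x16 half-precision tile.
// Every type and call goes through the alias, so no per-call #ifdefs remain.
__global__ void tile_mma(const half * A, const half * B, float * C) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 c;

    wmma::fill_fragment(c, 0.0f);      // zero the accumulator
    wmma::load_matrix_sync(a, A, 16);  // 16 = leading dimension
    wmma::load_matrix_sync(b, B, 16);
    wmma::mma_sync(c, a, b, c);        // c += a * b
    wmma::store_matrix_sync(C, c, 16, wmma::mem_col_major);
}

Because a namespace alias is fully transparent to name lookup and template arguments, the call sites compile identically against nvcuda::wmma and rocwmma as long as the two libraries expose the same names, which is exactly what rocWMMA is designed to do.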

File tree

1 file changed: +18 -70 lines changed

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 18 additions & 70 deletions
@@ -9,9 +9,11 @@
 #ifdef FP16_MMA_AVAILABLE
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+namespace wmma = nvcuda::wmma;
 #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
 #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #endif // FP16_MMA_AVAILABLE

@@ -73,19 +75,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
-#else
-    typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::row_major> frag_a_K;
-    typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::col_major> frag_a_V;
-    typedef rocwmma::fragment<rocwmma::matrix_b, frag_m, frag_n, 16, half, rocwmma::col_major> frag_b;
-    typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
-    typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;

     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -175,11 +169,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
-#else
-            rocwmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }

@@ -193,36 +183,20 @@ static __global__ void flash_attn_ext_f16(
             frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
-#else
-                rocwmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+                wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
             }
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
-#else
-                rocwmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+                wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
-#else
-                    rocwmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+                    wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
                 }
             }
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
-#else
-                rocwmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, rocwmma::mem_col_major);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+                wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
             }
         }
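
Worth noting in the hunk above: the old CUDA branch filled KQ_c with a bare 0.0f, while the ROCm branch already cast to KQ_acc_t. The unified call keeps the explicit static_cast, which both backends accept; it is a no-op when KQ_acc_t is float and simply makes the float-to-half conversion explicit otherwise, so behaviour is unchanged.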

@@ -341,17 +315,10 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
         for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
             const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-            nvcuda::wmma::load_matrix_sync(
+            wmma::load_matrix_sync(
                 KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                 KQ + j0*(kqar*kqs_padded) + k,
                 kqar*kqs_padded);
-#else
-            rocwmma::load_matrix_sync(
-                KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
-                KQ + j0*(kqar*kqs_padded) + k,
-                kqar*kqs_padded);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
         }
     }

@@ -360,30 +327,18 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
-#else
-                rocwmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
             }

 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

                 frag_a_V v_a;
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
-#else
-                rocwmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
-#else
-                    rocwmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }
@@ -395,17 +350,10 @@ static __global__ void flash_attn_ext_f16(
     for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-            nvcuda::wmma::store_matrix_sync(
+            wmma::store_matrix_sync(
                 KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                 VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                D_padded, nvcuda::wmma::mem_col_major);
-#else
-            rocwmma::store_matrix_sync(
-                KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
-                VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                D_padded, rocwmma::mem_col_major);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+                D_padded, wmma::mem_col_major);
         }
     }
