Commit 6e16165

Add: Inline-PTX in C++ for WGMMA
1 parent 0207843 commit 6e16165

1 file changed: less_slow.cu
Lines changed: 223 additions & 28 deletions
@@ -225,15 +225,15 @@ __device__ inline void binary_tops_tc_cuda_kernel( //
 
 #pragma region Volta
 
-__global__ void tops_f16f16_sm70tc_16x16x16_loop128_cuda_kernel() {
+__global__ void tops_f16f16_sm70wmma_16x16x16_loop128_cuda_kernel() {
     //? On Volta: 8x8x4.
     //? On Turing: 8x8x4 / 16x8x8 / 16x8x16.
     //? On Ampere: 16x8x8 / 16x8x16.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
     tops_tc_cuda_kernel<half, half, 16, 16, 16>();
 #endif
 }
-__global__ void tops_f16f32_sm70tc_16x16x16_loop128_cuda_kernel() {
+__global__ void tops_f16f32_sm70wmma_16x16x16_loop128_cuda_kernel() {
     //? On Volta: 8x8x4.
     //? On Turing: 8x8x4 / 16x8x8 / 16x8x16.
     //? On Ampere: 16x8x8 / 16x8x16.
@@ -246,22 +246,22 @@ __global__ void tops_f16f32_sm70tc_16x16x16_loop128_cuda_kernel() {
 
 #pragma region Turing
 
-__global__ void tops_u8i32_sm75tc_16x16x16_loop128_cuda_kernel() {
+__global__ void tops_u8i32_sm75wmma_16x16x16_loop128_cuda_kernel() {
     //? On Turing: 8x8x16.
     //? On Ampere: 8x8x16 / 16x8x16 / 16x8x32.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
     tops_tc_cuda_kernel<std::uint8_t, int32_t, 16, 16, 16>();
 #endif
 }
-__global__ void tops_u4i32_sm75tc_8x8x32_loop128_cuda_kernel() {
+__global__ void tops_u4i32_sm75wmma_8x8x32_loop128_cuda_kernel() {
     //! The 16x16x16 won't compile, 8x8x32 will.
     //? On Turing: 8x8x32.
     //? On Ampere: 8x8x32 / 16x8x32 / 16x8x64.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
     tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::u4, int32_t, 8, 8, 32>();
 #endif
 }
-__global__ void tops_b1i32xor_sm75tc_8x8x128_loop128_cuda_kernel() {
+__global__ void tops_b1i32xor_sm75wmma_8x8x128_loop128_cuda_kernel() {
     //! The 16x16x16 won't compile, 8x8x128 will.
     //? On Turing: 8x8x128.
     //? On Ampere: 8x8x128 / 16x8x128 / 16x8x256.
@@ -276,28 +276,28 @@ __global__ void tops_b1i32xor_sm75tc_8x8x128_loop128_cuda_kernel() {
 
 #pragma region Ampere
 
-__global__ void tops_bf16f32_sm80tc_16x16x16_loop128_cuda_kernel() {
+__global__ void tops_bf16f32_sm80wmma_16x16x16_loop128_cuda_kernel() {
     //? On Ampere: 16x8x8 / 16x8x16.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
     tops_tc_cuda_kernel<__nv_bfloat16, float, 16, 16, 16>();
 #endif
 }
-__global__ void tops_tf32f32_sm80tc_16x16x8_loop128_cuda_kernel() {
+__global__ void tops_tf32f32_sm80wmma_16x16x8_loop128_cuda_kernel() {
     //! The 16x16x16 won't compile, 16x16x8 will.
     //? On Ampere: 16x8x4.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
     tops_tc_cuda_kernel<nvcuda::wmma::precision::tf32, float, 16, 16, 8>();
 #endif
 }
-__global__ void tops_f64f64_sm80tc_8x8x4_loop128_cuda_kernel() {
+__global__ void tops_f64f64_sm80wmma_8x8x4_loop128_cuda_kernel() {
     //! The 16x16x16 won't compile, 8x8x4 will.
     //? On Ampere: 8x8x4.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
     tops_tc_cuda_kernel<double, double, 8, 8, 4>();
 #endif
 }
 
-__global__ void tops_b1i32and_sm80tc_8x8x128_loop128_cuda_kernel() {
+__global__ void tops_b1i32and_sm80wmma_8x8x128_loop128_cuda_kernel() {
     //! The 16x16x16 won't compile, 8x8x128 will.
     //? On Ampere: 8x8x128 / 16x8x128 / 16x8x256.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
@@ -322,13 +322,15 @@ __global__ void tops_b1i32and_sm80tc_8x8x128_loop128_cuda_kernel() {
  * 1. They can be asynchronous, for more flexible scheduling.
  * 2. They can avoid accumulation, a.k.a $C = A * B$, not $C += A * B$.
  *
- * The later are vastly more complex. Just compare our old MMA signature:
+ * The WGMMA instructions are vastly more complex.
+ *
+ * Just compare our old MMA signature:
  * ! {wmma.mma.sync.aligned}.{row.col}.{m16n16k16}.{f32.f32} { ........ }
  * ? { header }.{ layout}.{ shape }.{ types } { operands }
  *
  * To the new WGMMA signature:
- * ! {wgmma.mm_async.sync.aligned}.{m64n64k16}.{f32.f16.f16} { ........ },{ .... }
- * ? { much longer header }.{ shape }.{ types } { operands },{ args }
+ * ! {wgmma.mma_async.sync.aligned}.{m64n64k16}.{f32.f16.f16} { ........ },{ .... }
+ * ? { much longer header }.{ shape }.{ types } { operands },{ args }
  *
  * Not only the signature and "fragment" sizes differ, but also the scheduling
  * approach has changed between Ampere and Hopper once again:
@@ -343,39 +345,144 @@ __global__ void tops_b1i32and_sm80tc_8x8x128_loop128_cuda_kernel() {
  * to perform well - there can be a significant performance penalty if you
  * don't upgrade your PTX!
  *
- * To simplify the logic of higher-level Linear Algebra libraries, wrapper
- * templates from @b CUTLASS can be used. It has a smaller component called
- * @b CuTe, that wraps different kinds of MMA "atoms" - primitive kernel
- * templates. Just for Hopper alone, there is @b 10'000 lines of different
- * supported shape instantiations in @b `mma_sm90.hpp`.
- *
- * We can use CuTe to abstract away the right instructions, by defining small
- * shared memory matrices and performing such repeated "atom" instantiations.
- * We can also write "inline PTX" in CUDA C++, the same way we can write
- * "inline assembly" on the host side C++.
- *
  * @see "Fast Matrix-Multiplication with WGMMA on NVIDIA Hopper GPUs" by Colfax:
  *      https://research.colfax-intl.com/cutlass-tutorial-wgmma-hopper/
  * @see "Outperforming cuBLAS on H100: a Worklog" by Pranjal Shankhdhar:
  *      https://cudaforfun.substack.com/p/outperforming-cublas-on-h100-a-worklog
+ *
+ * To make things worse, there are no `wgmma::` CUDA C++ intrinsics!
+ * The closest thing to them is the @b CuTe low-level collection of C++
+ * templates, wrapping raw PTX instructions into MMA @b "atoms".
+ * Just for Hopper alone, there are @b 10'000 lines of different supported
+ * shape instantiations in @b `mma_sm90.hpp`.
+ *
  * @see CUTLASS updates: https://github.com/NVIDIA/cutlass/blob/main/CHANGELOG.md
  * @see CUTLASS GEMM API: https://github.com/NVIDIA/cutlass/blob/main/media/docs/gemm_api.md
  * @see "Deep Dive on CUTLASS Ping-Pong GEMM Kernel" by PyTorch:
  *      https://pytorch.org/blog/cutlass-ping-pong-gemm-kernel/
  * @see Minimal SM90 WGMMA + TMA GEMM example in 100 lines in CUTLASS 3.5.1:
  *      https://github.com/NVIDIA/cutlass/blob/main/examples/cute/tutorial/wgmma_sm90.cu
- * @see "Blackwell Cluster Launch Control" in CUTLASS docs:
- *      https://github.com/NVIDIA/cutlass/blob/main/media/docs/blackwell_cluster_launch_control.md
+ *
+ * We can also write "inline PTX" in CUDA C++, the same way we can write
+ * "inline assembly" on the host side C++.
+ *
+ * The instruction syntax for Warp-Group asynchronous instructions is very
+ * different, as at least one of the operand matrices has to be in shared
+ * memory (not registers). It's documented in 2 variants:
+ *
+ *    wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32
+ *        d, a-desc, b-desc, scale-d, imm-scale-a, imm-scale-b;
+ *    wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32
+ *        d, a, b-desc, scale-d, imm-scale-a, imm-scale-b;
+ *
+ * There is no "C" matrix involved at all, we are computing `D = A * B + D`.
+ * The `imm-scale` parameters can be used to either negate the inputs,
+ * or disable additive bias accumulation in the output. Both must be immediate
+ * values. The supported shapes list is also quite exhausting and differs for
+ * various numeric types. For half-precision floats:
+ *
+ *    .m64n8k16,   .m64n16k16,  .m64n24k16,  .m64n32k16,
+ *    .m64n40k16,  .m64n48k16,  .m64n56k16,  .m64n64k16,
+ *    .m64n72k16,  .m64n80k16,  .m64n88k16,  .m64n96k16,
+ *    .m64n104k16, .m64n112k16, .m64n120k16, .m64n128k16,
+ *    .m64n136k16, .m64n144k16, .m64n152k16, .m64n160k16,
+ *    .m64n168k16, .m64n176k16, .m64n184k16, .m64n192k16,
+ *    .m64n200k16, .m64n208k16, .m64n216k16, .m64n224k16,
+ *    .m64n232k16, .m64n240k16, .m64n248k16, .m64n256k16
+ *
+ */
+#pragma region Hopper
+
+/**
+ * Ideally, both matrices A and B should be in shared memory. Both are
+ * defined using 64-bit descriptors with the following layout:
+ *
+ * - 14 bits [0; 13]: start address
+ * - 14 bits [16; 29]: leading dimension byte offset
+ * - 14 bits [32; 45]: stride dimension byte offset
+ * - 3 bits [49; 51]: matrix base offset, valid only for "swizzling"
+ * - 2 bits [62; 63]: "swizzling" mode
+ *
+ * The matrix layout in WGMMA can be normal or transposed, but it's named
+ * differently. Non-Transposed for A and B is called K-Major. The Transposed
+ * variant is called M-Major for A and N-Major for B.
+ *
+ * The matrices in the shared memory are made up of one or more "swizzle
+ * layout atoms". The exact layout of these swizzle atoms depends on the
+ * swizzling mode, swizzle-atomicity, and the leading dimension.
+ *
+ * Swizzling defines the order of the elements and can have 4 possible values:
+ *
+ * 0: no "swizzling" at all
+ * 1: a 128-byte "swizzle" with a 1024-byte offset of a repeating pattern
+ * 2: a 64-byte "swizzle" with a 512-byte offset of a repeating pattern
+ * 3: a 32-byte "swizzle" with a 256-byte offset of a repeating pattern
+ *
+ * Here is how that logic is packed together:
  */
 __device__ std::uint64_t wgmma_descriptor( //
     std::uint64_t address, //
     std::uint64_t leading_offset, std::uint64_t stride_offset, std::uint64_t base_offset, //
     std::uint64_t swizzle) {
+    //! One of the most counter-intuitive things is how those matrix descriptors are composed.
+    //! All of the strides are in bytes, but divided by 16 (same as a right-shift by four).
     return ((address & 0x3FFFF) >> 4) | ((leading_offset >> 4) << 16) | ((stride_offset >> 4) << 32) |
            (base_offset << 49) | (swizzle << 62);
 }
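
A hedged usage sketch (an editor's illustration, not part of this commit): the kernels below feed `wgmma_descriptor` a plain cast of the shared-memory array, while the descriptor encodes a shared-memory address; the `__cvta_generic_to_shared` intrinsic performs that conversion explicitly. The helper name and the 128-byte leading / 256-byte stride offsets are assumptions borrowed from the kernels below, not canonical values for every layout.

__device__ std::uint64_t wgmma_descriptor_from_generic(void const *shared_tile) {
    // Convert a generic pointer into a shared-memory address first.
    std::uint64_t shared_address = static_cast<std::uint64_t>(__cvta_generic_to_shared(shared_tile));
    // Reuse the same no-swizzle offsets the benchmark kernels pass below.
    return wgmma_descriptor(shared_address, /*leading_offset=*/128, /*stride_offset=*/256,
                            /*base_offset=*/0, /*swizzle=*/0);
}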
 
+__device__ void wgmma_f16f32_64x256x16(float r[128], std::uint64_t a_descriptor, std::uint64_t b_descriptor) {
+    //! Interestingly, there are 2 variants of this instruction:
+    //! 1. Both arguments are in shared memory, in which case 2 immediate values
+    //!    can be used to transpose the inputs.
+    //! 2. One argument is in shared memory, and the other one is in the registers,
+    //!    in which case only one can be transposed, and only one immediate value
+    //!    for that can be supplied!
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+    asm volatile( //
+        "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 "
+        "{"
+        "%0, %1, %2, %3, %4, %5, %6, %7, "
+        "%8, %9, %10, %11, %12, %13, %14, %15, "
+        "%16, %17, %18, %19, %20, %21, %22, %23, "
+        "%24, %25, %26, %27, %28, %29, %30, %31, "
+        "%32, %33, %34, %35, %36, %37, %38, %39, "
+        "%40, %41, %42, %43, %44, %45, %46, %47, "
+        "%48, %49, %50, %51, %52, %53, %54, %55, "
+        "%56, %57, %58, %59, %60, %61, %62, %63, "
+        "%64, %65, %66, %67, %68, %69, %70, %71, "
+        "%72, %73, %74, %75, %76, %77, %78, %79, "
+        "%80, %81, %82, %83, %84, %85, %86, %87, "
+        "%88, %89, %90, %91, %92, %93, %94, %95, "
+        "%96, %97, %98, %99, %100, %101, %102, %103, "
+        "%104, %105, %106, %107, %108, %109, %110, %111, "
+        "%112, %113, %114, %115, %116, %117, %118, %119, "
+        "%120, %121, %122, %123, %124, %125, %126, %127"
+        "}, "
+        "%128, %129, "
+        "1, 1, 1, 0, 0;"
+        : "=f"(r[0]), "=f"(r[1]), "=f"(r[2]), "=f"(r[3]), "=f"(r[4]), "=f"(r[5]), "=f"(r[6]), "=f"(r[7]), "=f"(r[8]),
+          "=f"(r[9]), "=f"(r[10]), "=f"(r[11]), "=f"(r[12]), "=f"(r[13]), "=f"(r[14]), "=f"(r[15]), "=f"(r[16]),
+          "=f"(r[17]), "=f"(r[18]), "=f"(r[19]), "=f"(r[20]), "=f"(r[21]), "=f"(r[22]), "=f"(r[23]), "=f"(r[24]),
+          "=f"(r[25]), "=f"(r[26]), "=f"(r[27]), "=f"(r[28]), "=f"(r[29]), "=f"(r[30]), "=f"(r[31]), "=f"(r[32]),
+          "=f"(r[33]), "=f"(r[34]), "=f"(r[35]), "=f"(r[36]), "=f"(r[37]), "=f"(r[38]), "=f"(r[39]), "=f"(r[40]),
+          "=f"(r[41]), "=f"(r[42]), "=f"(r[43]), "=f"(r[44]), "=f"(r[45]), "=f"(r[46]), "=f"(r[47]), "=f"(r[48]),
+          "=f"(r[49]), "=f"(r[50]), "=f"(r[51]), "=f"(r[52]), "=f"(r[53]), "=f"(r[54]), "=f"(r[55]), "=f"(r[56]),
+          "=f"(r[57]), "=f"(r[58]), "=f"(r[59]), "=f"(r[60]), "=f"(r[61]), "=f"(r[62]), "=f"(r[63]), "=f"(r[64]),
+          "=f"(r[65]), "=f"(r[66]), "=f"(r[67]), "=f"(r[68]), "=f"(r[69]), "=f"(r[70]), "=f"(r[71]), "=f"(r[72]),
+          "=f"(r[73]), "=f"(r[74]), "=f"(r[75]), "=f"(r[76]), "=f"(r[77]), "=f"(r[78]), "=f"(r[79]), "=f"(r[80]),
+          "=f"(r[81]), "=f"(r[82]), "=f"(r[83]), "=f"(r[84]), "=f"(r[85]), "=f"(r[86]), "=f"(r[87]), "=f"(r[88]),
+          "=f"(r[89]), "=f"(r[90]), "=f"(r[91]), "=f"(r[92]), "=f"(r[93]), "=f"(r[94]), "=f"(r[95]), "=f"(r[96]),
+          "=f"(r[97]), "=f"(r[98]), "=f"(r[99]), "=f"(r[100]), "=f"(r[101]), "=f"(r[102]), "=f"(r[103]), "=f"(r[104]),
+          "=f"(r[105]), "=f"(r[106]), "=f"(r[107]), "=f"(r[108]), "=f"(r[109]), "=f"(r[110]), "=f"(r[111]),
+          "=f"(r[112]), "=f"(r[113]), "=f"(r[114]), "=f"(r[115]), "=f"(r[116]), "=f"(r[117]), "=f"(r[118]),
+          "=f"(r[119]), "=f"(r[120]), "=f"(r[121]), "=f"(r[122]), "=f"(r[123]), "=f"(r[124]), "=f"(r[125]),
+          "=f"(r[126]), "=f"(r[127])
+        : "l"(a_descriptor), "l"(b_descriptor));
+#endif
+}
+
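
A quick arithmetic check, as an editor's sketch (the helper name and `static_assert` are illustrative, not part of this commit): the 128 output registers per thread follow from the tile shape, since a warp-group has 128 threads and the m64n256 accumulator tile is split evenly across them.

constexpr int wgmma_accumulators_per_thread(int m, int n) { return m * n / 128; }
static_assert(wgmma_accumulators_per_thread(64, 256) == 128, "one `float` register per accumulator slot");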
 __device__ void wgmma_bf16f32_64x256x16(float r[128], std::uint64_t a_descriptor, std::uint64_t b_descriptor) {
+    // The `bf16` instructions are almost identical to `f16`.
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
     asm volatile( //
         "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 "
@@ -420,6 +527,53 @@ __device__ void wgmma_bf16f32_64x256x16(float r[128], std::uint64_t a_descriptor
 #endif
 }
 
+__device__ void wgmma_tf32f32_64x256x8(float r[128], std::uint64_t a_descriptor, std::uint64_t b_descriptor) {
+    //! Unlike the `f16` and `bf16` instructions, the `tf32` variant has fewer operands,
+    //! can't transpose the input matrices, and only comes in the narrower `k8` shapes!
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+    asm volatile( //
+        "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 "
+        "{"
+        "%0, %1, %2, %3, %4, %5, %6, %7, "
+        "%8, %9, %10, %11, %12, %13, %14, %15, "
+        "%16, %17, %18, %19, %20, %21, %22, %23, "
+        "%24, %25, %26, %27, %28, %29, %30, %31, "
+        "%32, %33, %34, %35, %36, %37, %38, %39, "
+        "%40, %41, %42, %43, %44, %45, %46, %47, "
+        "%48, %49, %50, %51, %52, %53, %54, %55, "
+        "%56, %57, %58, %59, %60, %61, %62, %63, "
+        "%64, %65, %66, %67, %68, %69, %70, %71, "
+        "%72, %73, %74, %75, %76, %77, %78, %79, "
+        "%80, %81, %82, %83, %84, %85, %86, %87, "
+        "%88, %89, %90, %91, %92, %93, %94, %95, "
+        "%96, %97, %98, %99, %100, %101, %102, %103, "
+        "%104, %105, %106, %107, %108, %109, %110, %111, "
+        "%112, %113, %114, %115, %116, %117, %118, %119, "
+        "%120, %121, %122, %123, %124, %125, %126, %127"
+        "}, "
+        "%128, %129, "
+        "1, 1, 1;"
+        : "=f"(r[0]), "=f"(r[1]), "=f"(r[2]), "=f"(r[3]), "=f"(r[4]), "=f"(r[5]), "=f"(r[6]), "=f"(r[7]), "=f"(r[8]),
+          "=f"(r[9]), "=f"(r[10]), "=f"(r[11]), "=f"(r[12]), "=f"(r[13]), "=f"(r[14]), "=f"(r[15]), "=f"(r[16]),
+          "=f"(r[17]), "=f"(r[18]), "=f"(r[19]), "=f"(r[20]), "=f"(r[21]), "=f"(r[22]), "=f"(r[23]), "=f"(r[24]),
+          "=f"(r[25]), "=f"(r[26]), "=f"(r[27]), "=f"(r[28]), "=f"(r[29]), "=f"(r[30]), "=f"(r[31]), "=f"(r[32]),
+          "=f"(r[33]), "=f"(r[34]), "=f"(r[35]), "=f"(r[36]), "=f"(r[37]), "=f"(r[38]), "=f"(r[39]), "=f"(r[40]),
+          "=f"(r[41]), "=f"(r[42]), "=f"(r[43]), "=f"(r[44]), "=f"(r[45]), "=f"(r[46]), "=f"(r[47]), "=f"(r[48]),
+          "=f"(r[49]), "=f"(r[50]), "=f"(r[51]), "=f"(r[52]), "=f"(r[53]), "=f"(r[54]), "=f"(r[55]), "=f"(r[56]),
+          "=f"(r[57]), "=f"(r[58]), "=f"(r[59]), "=f"(r[60]), "=f"(r[61]), "=f"(r[62]), "=f"(r[63]), "=f"(r[64]),
+          "=f"(r[65]), "=f"(r[66]), "=f"(r[67]), "=f"(r[68]), "=f"(r[69]), "=f"(r[70]), "=f"(r[71]), "=f"(r[72]),
+          "=f"(r[73]), "=f"(r[74]), "=f"(r[75]), "=f"(r[76]), "=f"(r[77]), "=f"(r[78]), "=f"(r[79]), "=f"(r[80]),
+          "=f"(r[81]), "=f"(r[82]), "=f"(r[83]), "=f"(r[84]), "=f"(r[85]), "=f"(r[86]), "=f"(r[87]), "=f"(r[88]),
+          "=f"(r[89]), "=f"(r[90]), "=f"(r[91]), "=f"(r[92]), "=f"(r[93]), "=f"(r[94]), "=f"(r[95]), "=f"(r[96]),
+          "=f"(r[97]), "=f"(r[98]), "=f"(r[99]), "=f"(r[100]), "=f"(r[101]), "=f"(r[102]), "=f"(r[103]), "=f"(r[104]),
+          "=f"(r[105]), "=f"(r[106]), "=f"(r[107]), "=f"(r[108]), "=f"(r[109]), "=f"(r[110]), "=f"(r[111]),
+          "=f"(r[112]), "=f"(r[113]), "=f"(r[114]), "=f"(r[115]), "=f"(r[116]), "=f"(r[117]), "=f"(r[118]),
+          "=f"(r[119]), "=f"(r[120]), "=f"(r[121]), "=f"(r[122]), "=f"(r[123]), "=f"(r[124]), "=f"(r[125]),
+          "=f"(r[126]), "=f"(r[127])
+        : "l"(a_descriptor), "l"(b_descriptor));
+#endif
+}
+
 __device__ void wgmma_commit_group() {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
     asm volatile("wgmma.commit_group.sync.aligned;");
@@ -432,9 +586,25 @@ __device__ void wgmma_wait_group() {
 #endif
 }
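
One more hedged sketch (an editor's illustration, not part of this commit): PTX also defines `wgmma.fence.sync.aligned`, which is expected before the first `wgmma.mma_async` of a sequence that touches the accumulator registers; a wrapper in the same style as the helpers above could look like this.

__device__ void wgmma_fence() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("wgmma.fence.sync.aligned;");
#endif
}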
 
-__global__ void tops_bf16f32_sm90tc_64x256x16_loop128_cuda_kernel() {
-    __shared__ __nv_bfloat16 a_shared[64][16];
-    __shared__ __nv_bfloat16 b_shared[256][16];
+__global__ void tops_f16f32_sm90wgmma_64x256x16_loop128_cuda_kernel() {
+    // 64x256x16 is the largest tile size for `f16` supported on Hopper.
+    // We could use `half` for the type, but `uint16_t` is more portable.
+    __shared__ std::uint16_t a_shared[64][16];
+    __shared__ std::uint16_t b_shared[256][16];
+
+    float c_registers[128] = {0.0f};
+    std::uint64_t a_descriptor = wgmma_descriptor((std::uint64_t)a_shared, 128, 256, 0, 0);
+    std::uint64_t b_descriptor = wgmma_descriptor((std::uint64_t)b_shared, 128 * 256 / 8, 128, 0, 0);
+    for (int i = 0; i != 128; ++i) wgmma_f16f32_64x256x16(c_registers, a_descriptor, b_descriptor);
+    wgmma_commit_group();
+    wgmma_wait_group();
+}
+
+__global__ void tops_bf16f32_sm90wgmma_64x256x16_loop128_cuda_kernel() {
+    // 64x256x16 is the largest tile size for `bf16` supported on Hopper.
+    // We could use `__nv_bfloat16` for the type, but `uint16_t` is more portable.
+    __shared__ std::uint16_t a_shared[64][16];
+    __shared__ std::uint16_t b_shared[256][16];
 
     float c_registers[128] = {0.0f};
     std::uint64_t a_descriptor = wgmma_descriptor((std::uint64_t)a_shared, 128, 256, 0, 0);
@@ -443,3 +613,28 @@ __global__ void tops_bf16f32_sm90tc_64x256x16_loop128_cuda_kernel() {
     wgmma_commit_group();
     wgmma_wait_group();
 }
+
+__global__ void tops_tf32f32_sm90wgmma_64x256x8_loop128_cuda_kernel() {
+    // 64x256x8 is the largest tile size for `tf32` supported on Hopper.
+    // Four-byte representations should be used for storage. The lowest 13
+    // mantissa bits of each entry are dropped before multiplication.
+    __shared__ std::uint32_t a_shared[64][16];
+    __shared__ std::uint32_t b_shared[256][16];
+
+    // TODO: Unlike smaller 2-byte floats, the stride sizes will be different here.
+    float c_registers[128] = {0.0f};
+    std::uint64_t a_descriptor = wgmma_descriptor((std::uint64_t)a_shared, 128, 256, 0, 0);
+    std::uint64_t b_descriptor = wgmma_descriptor((std::uint64_t)b_shared, 128 * 256 / 8, 128, 0, 0);
+    for (int i = 0; i != 128; ++i) wgmma_tf32f32_64x256x8(c_registers, a_descriptor, b_descriptor);
+    wgmma_commit_group();
+    wgmma_wait_group();
+}
+
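
For completeness, a hedged host-side sketch (an editor's illustration; the function name and launch configuration are assumptions, not part of this commit): WGMMA instructions are issued by a full warp-group, so each block needs at least 128 threads.

void launch_sm90wgmma_kernels() {
    // One block of 128 threads is exactly one warp-group - the minimum WGMMA requires.
    tops_f16f32_sm90wgmma_64x256x16_loop128_cuda_kernel<<<1, 128>>>();
    tops_bf16f32_sm90wgmma_64x256x16_loop128_cuda_kernel<<<1, 128>>>();
    cudaDeviceSynchronize();
}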
+#pragma endregion
+
+/**
+ *
+ * @see "Blackwell Cluster Launch Control" in CUTLASS docs:
+ *      https://github.com/NVIDIA/cutlass/blob/main/media/docs/blackwell_cluster_launch_control.md
+ *
+ */
