@@ -225,15 +225,15 @@ __device__ inline void binary_tops_tc_cuda_kernel( //
225225
226226#pragma region Volta
227227
228- __global__ void tops_f16f16_sm70tc_16x16x16_loop128_cuda_kernel () {
228+ __global__ void tops_f16f16_sm70wmma_16x16x16_loop128_cuda_kernel () {
229229 // ? On Volta: 8x8x4.
230230 // ? On Turing: 8x8x4 / 16x8x8 / 16x8x16.
231231 // ? On Ampere: 16x8x8 / 16x8x16.
232232#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
233233 tops_tc_cuda_kernel<half, half, 16 , 16 , 16 >();
234234#endif
235235}
236- __global__ void tops_f16f32_sm70tc_16x16x16_loop128_cuda_kernel () {
236+ __global__ void tops_f16f32_sm70wmma_16x16x16_loop128_cuda_kernel () {
237237 // ? On Volta: 8x8x4.
238238 // ? On Turing: 8x8x4 / 16x8x8 / 16x8x16.
239239 // ? On Ampere: 16x8x8 / 16x8x16.
@@ -246,22 +246,22 @@ __global__ void tops_f16f32_sm70tc_16x16x16_loop128_cuda_kernel() {
246246
247247#pragma region Turing
248248
249- __global__ void tops_u8i32_sm75tc_16x16x16_loop128_cuda_kernel () {
249+ __global__ void tops_u8i32_sm75wmma_16x16x16_loop128_cuda_kernel () {
250250 // ? On Turing: 8x8x16.
251251 // ? On Ampere: 8x8x16 / 16x8x16 / 16x8x32.
252252#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
253253 tops_tc_cuda_kernel<std::uint8_t , int32_t , 16 , 16 , 16 >();
254254#endif
255255}
256- __global__ void tops_u4i32_sm75tc_8x8x32_loop128_cuda_kernel () {
256+ __global__ void tops_u4i32_sm75wmma_8x8x32_loop128_cuda_kernel () {
257257 // ! The 16x16x16 won't compile, 8x8x32 will.
258258 // ? On Turing: 8x8x32.
259259 // ? On Ampere: 8x8x32 / 16x8x32 / 16x8x64.
260260#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
261261 tops_tc_cuda_kernel<nvcuda::wmma::experimental::precision::u4, int32_t , 8 , 8 , 32 >();
262262#endif
263263}
264- __global__ void tops_b1i32xor_sm75tc_8x8x128_loop128_cuda_kernel () {
264+ __global__ void tops_b1i32xor_sm75wmma_8x8x128_loop128_cuda_kernel () {
265265 // ! The 16x16x16 won't compile, 8x8x128 will.
266266 // ? On Turing: 8x8x128.
267267 // ? On Ampere: 8x8x128 / 16x8x128 / 16x8x256.
@@ -276,28 +276,28 @@ __global__ void tops_b1i32xor_sm75tc_8x8x128_loop128_cuda_kernel() {
276276
277277#pragma region Ampere
278278
279- __global__ void tops_bf16f32_sm80tc_16x16x16_loop128_cuda_kernel () {
279+ __global__ void tops_bf16f32_sm80wmma_16x16x16_loop128_cuda_kernel () {
280280 // ? On Ampere: 16x8x8 / 16x8x16.
281281#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
282282 tops_tc_cuda_kernel<__nv_bfloat16, float , 16 , 16 , 16 >();
283283#endif
284284}
285- __global__ void tops_tf32f32_sm80tc_16x16x8_loop128_cuda_kernel () {
285+ __global__ void tops_tf32f32_sm80wmma_16x16x8_loop128_cuda_kernel () {
286286 // ! The 16x16x16 won't compile, 16x16x8 will.
287287 // ? On Ampere: 16x8x4.
288288#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
289289 tops_tc_cuda_kernel<nvcuda::wmma::precision::tf32, float , 16 , 16 , 8 >();
290290#endif
291291}
292- __global__ void tops_f64f64_sm80tc_8x8x4_loop128_cuda_kernel () {
292+ __global__ void tops_f64f64_sm80wmma_8x8x4_loop128_cuda_kernel () {
293293 // ! The 16x16x16 won't compile, 8x8x4 will.
294294 // ? On Ampere: 8x8x4.
295295#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
296296 tops_tc_cuda_kernel<double , double , 8 , 8 , 4 >();
297297#endif
298298}
299299
300- __global__ void tops_b1i32and_sm80tc_8x8x128_loop128_cuda_kernel () {
300+ __global__ void tops_b1i32and_sm80wmma_8x8x128_loop128_cuda_kernel () {
301301 // ! The 16x16x16 won't compile, 8x8x128 will.
302302 // ? On Ampere: 8x8x128 / 16x8x128 / 16x8x256.
303303#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
@@ -322,13 +322,15 @@ __global__ void tops_b1i32and_sm80tc_8x8x128_loop128_cuda_kernel() {
322322 * 1. They can be asynchronous, for more flexible scheduling.
323323 * 2. They can avoid accumulation, a.k.a $C = A * B$, not $C += A * B$.
324324 *
325- * The later are vastly more complex. Just compare our old MMA signature:
325+  * The WGMMA instructions are vastly more complex.
326+ *
327+ * Just compare our old MMA signature:
326328 * ! {wmma.mma.sync.aligned}.{row.col}.{m16n16k16}.{f32.f32} { ........ }
327329 * ? { header }.{ layout}.{ shape }.{ types } { operands }
328330 *
329331 * To the new WGMMA signature:
330- * ! {wgmma.mm_async .sync.aligned}.{m64n64k16}.{f32.f16.f16} { ........ },{ .... }
331- * ? { much longer header }.{ shape }.{ types } { operands },{ args }
332+ * ! {wgmma.mma_async .sync.aligned}.{m64n64k16}.{f32.f16.f16} { ........ },{ .... }
333+ * ? { much longer header }.{ shape }.{ types } { operands },{ args }
332334 *
333335 * Not only the signature and "fragment" sizes differ, but also the scheduling
334336 * approach has changed between Ampere and Hopper once again:
@@ -343,39 +345,144 @@ __global__ void tops_b1i32and_sm80tc_8x8x128_loop128_cuda_kernel() {
343345 * to perform well - there can be a significant performance penalty if you
344346 * don't upgrade your PTX!
345347 *
346- * To simplify the logic of higher-level Linear Algebra libraries, wrapper
347- * templates from @b CUTLASS can be used. It has a smaller component called
348- * @b CuTe, that wraps different kinds of MMA "atoms" - primitive kernel
349- * templates. Just for Hopper alone, there is @b 10'000 lines of different
350- * supported shape instantiations in @b `mma_sm90.hpp`.
351- *
352- * We can use CuTe to abstract away the right instructions, by defining small
353- * shared memory matrices and performing such repeated "atom" instantiations.
354- * We can also write "inline PTX" in CUDA C++, the same way we can write
355- * "inline assembly" on the host side C++.
356- *
357348 * @see "Fast Matrix-Multiplication with WGMMA on NVIDIA Hopper GPUs" by Colfax:
358349 * https://research.colfax-intl.com/cutlass-tutorial-wgmma-hopper/
359350 * @see "Outperforming cuBLAS on H100: a Worklog" by Pranjal Shankhdhar:
360351 * https://cudaforfun.substack.com/p/outperforming-cublas-on-h100-a-worklog
352+ *
353+ * To make things worse, there are no `wgmma::` CUDA C++ intrinsics!
354+ * The closest thing to them is the @b CuTe low-level collection of C++
355+ * templates, wrapping raw PTX instructions into MMA @b "atoms".
356+  * Just for Hopper alone, there are @b 10'000 lines of different supported
357+ * shape instantiations in @b `mma_sm90.hpp`.
358+ *
361359 * @see CUTLASS updates: https://github.com/NVIDIA/cutlass/blob/main/CHANGELOG.md
362360 * @see CUTLASS GEMM API: https://github.com/NVIDIA/cutlass/blob/main/media/docs/gemm_api.md
363361 * @see "Deep Dive on CUTLASS Ping-Pong GEMM Kernel" by PyTorch:
364362 * https://pytorch.org/blog/cutlass-ping-pong-gemm-kernel/
365363 * @see Minimal SM90 WGMMA + TMA GEMM example in 100 lines in CUTLASS 3.5.1:
366364 * https://github.com/NVIDIA/cutlass/blob/main/examples/cute/tutorial/wgmma_sm90.cu
367- * @see "Blackwell Cluster Launch Control" in CUTLASS docs:
368- * https://github.com/NVIDIA/cutlass/blob/main/media/docs/blackwell_cluster_launch_control.md
365+ *
366+ * We can also write "inline PTX" in CUDA C++, the same way we can write
367+ * "inline assembly" on the host side C++.
368+ *
369+ * The instruction syntax for Warp-Group asynchronous instructions is very
370+ * different, as at least one of the operand matrices has to be in shared
371+  * memory (not registers). It is documented in two variants:
372+ *
373+ * wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32
374+ * d, a-desc, b-desc, scale-d, imm-scale-a, imm-scale-b;
375+ * wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32
376+ * d, a, b-desc, scale-d, imm-scale-a, imm-scale-b;
377+ *
378+  * There is no "C" matrix involved at all; we are computing `D = A * B + D`.
379+  * The `imm-scale-a` / `imm-scale-b` parameters (+1 or -1) can negate the inputs,
380+  * and `scale-d` (0 or 1) can disable the additive accumulation in the output.
381+  * All of them must be immediate values. The supported shape list is also quite
382+  * exhaustive and differs across numeric types. For the `tf32` inputs above:
383+ *
384+ * .m64n8k8, .m64n16k8, .m64n24k8, .m64n32k8,
385+ * .m64n40k8, .m64n48k8, .m64n56k8, .m64n64k8,
386+ * .m64n72k8, .m64n80k8, .m64n88k8, .m64n96k8,
387+ * .m64n104k8, .m64n112k8, .m64n120k8, .m64n128k8,
388+ * .m64n136k8, .m64n144k8, .m64n152k8, .m64n160k8,
389+ * .m64n168k8, .m64n176k8, .m64n184k8, .m64n192k8,
390+ * .m64n200k8, .m64n208k8, .m64n216k8, .m64n224k8,
391+ * .m64n232k8, .m64n240k8, .m64n248k8, .m64n256k8
392+
393+ */
394+ #pragma region Hopper
395+
396+ /**
397+ * Ideally, both matrices A and B should be in shared memory. Both are
398+ * defined using 64-bit descriptors with the following layout:
399+ *
400+ * - 14 bits [0; 13]: start address
401+ * - 14 bits [16; 29]: leading dimension byte offset
402+ * - 14 bits [32; 45]: stride dimension byte offset
403+ * - 3 bits [49; 51]: matrix base offset, valid only for "swizzling"
404+ * - 2 bits [62; 63]: "swizzling" mode
405+ *
406+  * The matrix layout in WGMMA can be normal or transposed, but it is named
407+  * differently: the non-transposed layout is called K-Major for both A and B,
408+  * while the transposed variant is called M-Major for A and N-Major for B.
409+ *
410+  * The matrices in shared memory are made up of one or more "swizzle layout
411+  * atoms". The exact layout of these atoms depends on the swizzling mode,
412+  * the swizzle-atomicity, and the leading dimension.
413+ *
414+ * Swizzling defines the order of the elements and can have 4 possible values:
415+ *
416+ * 0: no "swizzling" at all
417+ * 1: a 128-byte "swizzle" with a 1024 byte offset of a repeating pattern
418+ * 2: a 64-byte "swizzle" with a 512 byte offset of a repeating pattern
419+ * 3: a 32-byte "swizzle" with a 256 byte offset of a repeating pattern
420+ *
421+ * Here is how that logic is packed together:
369422 */
370423__device__ std::uint64_t wgmma_descriptor ( //
371424 std::uint64_t address, //
372425 std::uint64_t leading_offset, std::uint64_t stride_offset, std::uint64_t base_offset, //
373426 std::uint64_t swizzle) {
427+ // ! One of the most counter-intuitive things is how those matrix descriptors are composed.
428+ // ! All fo the strides are in bytes, but divided by 16 (same as right-sift by four).
374429 return ((address & 0x3FFFF ) >> 4 ) | ((leading_offset >> 4 ) << 16 ) | ((stride_offset >> 4 ) << 32 ) |
375430 (base_offset << 49 ) | (swizzle << 62 );
376431}
377432
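// ? A quick worked example of the packing above, assuming a hypothetical shared-memory
// ? address of 0x400, a 128-byte leading offset, a 256-byte stride, and no swizzling:
// ?
// ?     wgmma_descriptor(0x400, 128, 256, 0, 0) ==
// ?         (0x400 >> 4) | ((128ull >> 4) << 16) | ((256ull >> 4) << 32) ==
// ?         0x40 | (8ull << 16) | (16ull << 32)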
433+ __device__ void wgmma_f16f32_64x256x16 (float r[128 ], std::uint64_t a_descriptor, std::uint64_t b_descriptor) {
434+ // ! Interestingly, there are 2 variants of this instruction:
435+ // ! 1. Both arguments are in shared memory, in which case 2 immediate values
436+ // ! can be used to transpose the inputs.
437+ // ! 2. One argument is in shared memory, and the other one is in the registers,
438+ // ! in which case only one can be transposed, and only one immediate value
439+ // ! for that can be supplied!
440+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
441+ asm volatile ( //
442+ " wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 "
443+ " {"
444+ " %0, %1, %2, %3, %4, %5, %6, %7, "
445+ " %8, %9, %10, %11, %12, %13, %14, %15, "
446+ " %16, %17, %18, %19, %20, %21, %22, %23, "
447+ " %24, %25, %26, %27, %28, %29, %30, %31, "
448+ " %32, %33, %34, %35, %36, %37, %38, %39, "
449+ " %40, %41, %42, %43, %44, %45, %46, %47, "
450+ " %48, %49, %50, %51, %52, %53, %54, %55, "
451+ " %56, %57, %58, %59, %60, %61, %62, %63, "
452+ " %64, %65, %66, %67, %68, %69, %70, %71, "
453+ " %72, %73, %74, %75, %76, %77, %78, %79, "
454+ " %80, %81, %82, %83, %84, %85, %86, %87, "
455+ " %88, %89, %90, %91, %92, %93, %94, %95, "
456+ " %96, %97, %98, %99, %100, %101, %102, %103, "
457+ " %104, %105, %106, %107, %108, %109, %110, %111, "
458+ " %112, %113, %114, %115, %116, %117, %118, %119, "
459+ " %120, %121, %122, %123, %124, %125, %126, %127"
460+ " }, "
461+ " %128, %129, "
462+ " 1, 1, 1, 0, 0;"
463+ : " =f" (r[0 ]), " =f" (r[1 ]), " =f" (r[2 ]), " =f" (r[3 ]), " =f" (r[4 ]), " =f" (r[5 ]), " =f" (r[6 ]), " =f" (r[7 ]), " =f" (r[8 ]),
464+ " =f" (r[9 ]), " =f" (r[10 ]), " =f" (r[11 ]), " =f" (r[12 ]), " =f" (r[13 ]), " =f" (r[14 ]), " =f" (r[15 ]), " =f" (r[16 ]),
465+ " =f" (r[17 ]), " =f" (r[18 ]), " =f" (r[19 ]), " =f" (r[20 ]), " =f" (r[21 ]), " =f" (r[22 ]), " =f" (r[23 ]), " =f" (r[24 ]),
466+ " =f" (r[25 ]), " =f" (r[26 ]), " =f" (r[27 ]), " =f" (r[28 ]), " =f" (r[29 ]), " =f" (r[30 ]), " =f" (r[31 ]), " =f" (r[32 ]),
467+ " =f" (r[33 ]), " =f" (r[34 ]), " =f" (r[35 ]), " =f" (r[36 ]), " =f" (r[37 ]), " =f" (r[38 ]), " =f" (r[39 ]), " =f" (r[40 ]),
468+ " =f" (r[41 ]), " =f" (r[42 ]), " =f" (r[43 ]), " =f" (r[44 ]), " =f" (r[45 ]), " =f" (r[46 ]), " =f" (r[47 ]), " =f" (r[48 ]),
469+ " =f" (r[49 ]), " =f" (r[50 ]), " =f" (r[51 ]), " =f" (r[52 ]), " =f" (r[53 ]), " =f" (r[54 ]), " =f" (r[55 ]), " =f" (r[56 ]),
470+ " =f" (r[57 ]), " =f" (r[58 ]), " =f" (r[59 ]), " =f" (r[60 ]), " =f" (r[61 ]), " =f" (r[62 ]), " =f" (r[63 ]), " =f" (r[64 ]),
471+ " =f" (r[65 ]), " =f" (r[66 ]), " =f" (r[67 ]), " =f" (r[68 ]), " =f" (r[69 ]), " =f" (r[70 ]), " =f" (r[71 ]), " =f" (r[72 ]),
472+ " =f" (r[73 ]), " =f" (r[74 ]), " =f" (r[75 ]), " =f" (r[76 ]), " =f" (r[77 ]), " =f" (r[78 ]), " =f" (r[79 ]), " =f" (r[80 ]),
473+ " =f" (r[81 ]), " =f" (r[82 ]), " =f" (r[83 ]), " =f" (r[84 ]), " =f" (r[85 ]), " =f" (r[86 ]), " =f" (r[87 ]), " =f" (r[88 ]),
474+ " =f" (r[89 ]), " =f" (r[90 ]), " =f" (r[91 ]), " =f" (r[92 ]), " =f" (r[93 ]), " =f" (r[94 ]), " =f" (r[95 ]), " =f" (r[96 ]),
475+ " =f" (r[97 ]), " =f" (r[98 ]), " =f" (r[99 ]), " =f" (r[100 ]), " =f" (r[101 ]), " =f" (r[102 ]), " =f" (r[103 ]), " =f" (r[104 ]),
476+ " =f" (r[105 ]), " =f" (r[106 ]), " =f" (r[107 ]), " =f" (r[108 ]), " =f" (r[109 ]), " =f" (r[110 ]), " =f" (r[111 ]),
477+ " =f" (r[112 ]), " =f" (r[113 ]), " =f" (r[114 ]), " =f" (r[115 ]), " =f" (r[116 ]), " =f" (r[117 ]), " =f" (r[118 ]),
478+ " =f" (r[119 ]), " =f" (r[120 ]), " =f" (r[121 ]), " =f" (r[122 ]), " =f" (r[123 ]), " =f" (r[124 ]), " =f" (r[125 ]),
479+ " =f" (r[126 ]), " =f" (r[127 ])
480+ : " l" (a_descriptor), " l" (b_descriptor));
481+ #endif
482+ }
483+
378484__device__ void wgmma_bf16f32_64x256x16 (float r[128 ], std::uint64_t a_descriptor, std::uint64_t b_descriptor) {
485+     // The `bf16` instruction is almost identical to the `f16` one above.
379486#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
380487 asm volatile ( //
381488 " wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 "
@@ -420,6 +527,53 @@ __device__ void wgmma_bf16f32_64x256x16(float r[128], std::uint64_t a_descriptor
420527#endif
421528}
422529
530+ __device__ void wgmma_tf32f32_64x256x16 (float r[128 ], std::uint64_t a_descriptor, std::uint64_t b_descriptor) {
531+     // ! Unlike the `f16` and `bf16` instructions, the `tf32` variant has fewer
532+     // ! operands and can't transpose the input matrices!
533+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
534+ asm volatile ( //
535+ " wgmma.mma_async.sync.aligned.m64n256k16.f32.tf32.tf32 "
536+ " {"
537+ " %0, %1, %2, %3, %4, %5, %6, %7, "
538+ " %8, %9, %10, %11, %12, %13, %14, %15, "
539+ " %16, %17, %18, %19, %20, %21, %22, %23, "
540+ " %24, %25, %26, %27, %28, %29, %30, %31, "
541+ " %32, %33, %34, %35, %36, %37, %38, %39, "
542+ " %40, %41, %42, %43, %44, %45, %46, %47, "
543+ " %48, %49, %50, %51, %52, %53, %54, %55, "
544+ " %56, %57, %58, %59, %60, %61, %62, %63, "
545+ " %64, %65, %66, %67, %68, %69, %70, %71, "
546+ " %72, %73, %74, %75, %76, %77, %78, %79, "
547+ " %80, %81, %82, %83, %84, %85, %86, %87, "
548+ " %88, %89, %90, %91, %92, %93, %94, %95, "
549+ " %96, %97, %98, %99, %100, %101, %102, %103, "
550+ " %104, %105, %106, %107, %108, %109, %110, %111, "
551+ " %112, %113, %114, %115, %116, %117, %118, %119, "
552+ " %120, %121, %122, %123, %124, %125, %126, %127"
553+ " }, "
554+ " %128, %129, "
555+ " 1, 1, 1;"
556+ : " =f" (r[0 ]), " =f" (r[1 ]), " =f" (r[2 ]), " =f" (r[3 ]), " =f" (r[4 ]), " =f" (r[5 ]), " =f" (r[6 ]), " =f" (r[7 ]), " =f" (r[8 ]),
557+ " =f" (r[9 ]), " =f" (r[10 ]), " =f" (r[11 ]), " =f" (r[12 ]), " =f" (r[13 ]), " =f" (r[14 ]), " =f" (r[15 ]), " =f" (r[16 ]),
558+ " =f" (r[17 ]), " =f" (r[18 ]), " =f" (r[19 ]), " =f" (r[20 ]), " =f" (r[21 ]), " =f" (r[22 ]), " =f" (r[23 ]), " =f" (r[24 ]),
559+ " =f" (r[25 ]), " =f" (r[26 ]), " =f" (r[27 ]), " =f" (r[28 ]), " =f" (r[29 ]), " =f" (r[30 ]), " =f" (r[31 ]), " =f" (r[32 ]),
560+ " =f" (r[33 ]), " =f" (r[34 ]), " =f" (r[35 ]), " =f" (r[36 ]), " =f" (r[37 ]), " =f" (r[38 ]), " =f" (r[39 ]), " =f" (r[40 ]),
561+ " =f" (r[41 ]), " =f" (r[42 ]), " =f" (r[43 ]), " =f" (r[44 ]), " =f" (r[45 ]), " =f" (r[46 ]), " =f" (r[47 ]), " =f" (r[48 ]),
562+ " =f" (r[49 ]), " =f" (r[50 ]), " =f" (r[51 ]), " =f" (r[52 ]), " =f" (r[53 ]), " =f" (r[54 ]), " =f" (r[55 ]), " =f" (r[56 ]),
563+ " =f" (r[57 ]), " =f" (r[58 ]), " =f" (r[59 ]), " =f" (r[60 ]), " =f" (r[61 ]), " =f" (r[62 ]), " =f" (r[63 ]), " =f" (r[64 ]),
564+ " =f" (r[65 ]), " =f" (r[66 ]), " =f" (r[67 ]), " =f" (r[68 ]), " =f" (r[69 ]), " =f" (r[70 ]), " =f" (r[71 ]), " =f" (r[72 ]),
565+ " =f" (r[73 ]), " =f" (r[74 ]), " =f" (r[75 ]), " =f" (r[76 ]), " =f" (r[77 ]), " =f" (r[78 ]), " =f" (r[79 ]), " =f" (r[80 ]),
566+ " =f" (r[81 ]), " =f" (r[82 ]), " =f" (r[83 ]), " =f" (r[84 ]), " =f" (r[85 ]), " =f" (r[86 ]), " =f" (r[87 ]), " =f" (r[88 ]),
567+ " =f" (r[89 ]), " =f" (r[90 ]), " =f" (r[91 ]), " =f" (r[92 ]), " =f" (r[93 ]), " =f" (r[94 ]), " =f" (r[95 ]), " =f" (r[96 ]),
568+ " =f" (r[97 ]), " =f" (r[98 ]), " =f" (r[99 ]), " =f" (r[100 ]), " =f" (r[101 ]), " =f" (r[102 ]), " =f" (r[103 ]), " =f" (r[104 ]),
569+ " =f" (r[105 ]), " =f" (r[106 ]), " =f" (r[107 ]), " =f" (r[108 ]), " =f" (r[109 ]), " =f" (r[110 ]), " =f" (r[111 ]),
570+ " =f" (r[112 ]), " =f" (r[113 ]), " =f" (r[114 ]), " =f" (r[115 ]), " =f" (r[116 ]), " =f" (r[117 ]), " =f" (r[118 ]),
571+ " =f" (r[119 ]), " =f" (r[120 ]), " =f" (r[121 ]), " =f" (r[122 ]), " =f" (r[123 ]), " =f" (r[124 ]), " =f" (r[125 ]),
572+ " =f" (r[126 ]), " =f" (r[127 ])
573+ : " l" (a_descriptor), " l" (b_descriptor));
574+ #endif
575+ }
576+
423577__device__ void wgmma_commit_group () {
424578#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
425579 asm volatile (" wgmma.commit_group.sync.aligned;" );
@@ -432,9 +586,25 @@ __device__ void wgmma_wait_group() {
432586#endif
433587}
434588
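// ? A minimal sketch of a fence helper: the PTX ISA also defines
// ? `wgmma.fence.sync.aligned`, which should be issued by the warp-group before its
// ? first `wgmma.mma_async` and before reusing the accumulator registers.
__device__ void wgmma_fence() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("wgmma.fence.sync.aligned;");
#endif
}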
435- __global__ void tops_bf16f32_sm90tc_64x256x16_loop128_cuda_kernel () {
436- __shared__ __nv_bfloat16 a_shared[64 ][16 ];
437- __shared__ __nv_bfloat16 b_shared[256 ][16 ];
589+ __global__ void tops_f16f32_sm90wgmma_64x256x16_loop128_cuda_kernel () {
590+ // 64x256x16 is the largest tile size for `f16` supported on Hopper.
591+     // We can use `half` for the type, but `uint16_t` is more portable.
592+     __shared__ std::uint16_t a_shared[64][16];
593+     __shared__ std::uint16_t b_shared[256][16];
594+ 
595+     float c_registers[128] = {0.0f};
596+     std::uint64_t a_descriptor = wgmma_descriptor((std::uint64_t)a_shared, 128, 256, 0, 0);
597+     std::uint64_t b_descriptor = wgmma_descriptor((std::uint64_t)b_shared, 128 * 256 / 8, 128, 0, 0);
598+     for (int i = 0; i != 128; ++i) wgmma_f16f32_64x256x16(c_registers, a_descriptor, b_descriptor);
599+ wgmma_commit_group ();
600+ wgmma_wait_group ();
601+ }
602+
603+ __global__ void tops_bf16f32_sm90wgmma_64x256x16_loop128_cuda_kernel () {
604+ // 64x256x16 is the largest tile size for `bf16` supported on Hopper.
605+     // We can use `__nv_bfloat16` for the type, but `uint16_t` is more portable.
606+     __shared__ std::uint16_t a_shared[64][16];
607+     __shared__ std::uint16_t b_shared[256][16];
438608
439609     float c_registers[128] = {0.0f};
440610     std::uint64_t a_descriptor = wgmma_descriptor((std::uint64_t)a_shared, 128, 256, 0, 0);
@@ -443,3 +613,28 @@ __global__ void tops_bf16f32_sm90tc_64x256x16_loop128_cuda_kernel() {
443613 wgmma_commit_group ();
444614 wgmma_wait_group ();
445615}
616+
617+ __global__ void tops_tf32f32_sm90wgmma_64x256x16_loop128_cuda_kernel () {
618+ // 64x256x16 is the largest tile size for `tf32` supported on Hopper.
619+     // Four-byte representations should be used for storage; the lowest 13 mantissa
620+     // bits of each entry are dropped before multiplication.
621+     __shared__ std::uint32_t a_shared[64][16];
622+     __shared__ std::uint32_t b_shared[256][16];
623+ 
624+     // TODO: Unlike the smaller 2-byte floats, the stride sizes will differ here.
625+     float c_registers[128] = {0.0f};
626+     std::uint64_t a_descriptor = wgmma_descriptor((std::uint64_t)a_shared, 128, 256, 0, 0);
627+     std::uint64_t b_descriptor = wgmma_descriptor((std::uint64_t)b_shared, 128 * 256 / 8, 128, 0, 0);
628+     for (int i = 0; i != 128; ++i) wgmma_tf32f32_64x256x16(c_registers, a_descriptor, b_descriptor);
629+ wgmma_commit_group ();
630+ wgmma_wait_group ();
631+ }
632+
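// ? A hypothetical host-side launch sketch (the `launch_sm90wgmma_kernels` name and the
// ? grid size are illustrative): WGMMA operates on a full warp-group, so each block
// ? needs a multiple of 128 threads.
void launch_sm90wgmma_kernels() {
    tops_f16f32_sm90wgmma_64x256x16_loop128_cuda_kernel<<<1024, 128>>>();
    tops_bf16f32_sm90wgmma_64x256x16_loop128_cuda_kernel<<<1024, 128>>>();
    tops_tf32f32_sm90wgmma_64x256x16_loop128_cuda_kernel<<<1024, 128>>>();
    cudaDeviceSynchronize();
}
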
633+ #pragma endregion
634+
635+ /**
636+ *
637+ * @see "Blackwell Cluster Launch Control" in CUTLASS docs:
638+ * https://github.com/NVIDIA/cutlass/blob/main/media/docs/blackwell_cluster_launch_control.md
639+ *
640+ */