@@ -527,12 +527,12 @@ __device__ void wgmma_bf16f32_64x256x16(float r[128], std::uint64_t a_descriptor
527527#endif
528528}
529529
530- __device__ void wgmma_tf32f32_64x256x16 (float r[128 ], std::uint64_t a_descriptor, std::uint64_t b_descriptor) {
530+ __device__ void wgmma_tf32f32_64x256x8 (float r[128 ], std::uint64_t a_descriptor, std::uint64_t b_descriptor) {
531531 // ! Unlike the `f16` and `bf16` instructions, the `tf32` has fewer operands,
532532 // ! and can't transpose the input matrices!
533533#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
534534 asm volatile ( //
535- " wgmma.mma_async.sync.aligned.m64n256k16 .f32.tf32.tf32 "
535+ " wgmma.mma_async.sync.aligned.m64n256k8 .f32.tf32.tf32 "
536536 " {"
537537 " %0, %1, %2, %3, %4, %5, %6, %7, "
538538 " %8, %9, %10, %11, %12, %13, %14, %15, "
@@ -574,15 +574,21 @@ __device__ void wgmma_tf32f32_64x256x16(float r[128], std::uint64_t a_descriptor
574574#endif
575575}
576576
/**
 *  @brief Issues a `wgmma.fence` for the calling warpgroup. Requires SM90+
 *         (Hopper); compiles to a no-op on older architectures and on host.
 *
 *  Per the PTX ISA, `wgmma.fence.sync.aligned` must be executed before the
 *  first `wgmma.mma_async` of a sequence and after any register access to
 *  the accumulator fragment. The `"memory"` clobber is required so the
 *  compiler does not reorder surrounding memory operations (e.g. the
 *  shared-memory tile stores the descriptors point at) across the fence.
 */
__device__ void wgmma_fence() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("wgmma.fence.sync.aligned;" ::: "memory");
#endif
}
582+
/**
 *  @brief Commits all prior uncommitted `wgmma.mma_async` operations of the
 *         calling warpgroup into a new wgmma-group. Requires SM90+ (Hopper);
 *         compiles to a no-op on older architectures and on host.
 *
 *  The `"memory"` clobber keeps the compiler from moving memory operations
 *  across the commit point, matching how production wrappers (e.g. CUTLASS)
 *  emit this instruction.
 */
__device__ void wgmma_commit_group() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("wgmma.commit_group.sync.aligned;" ::: "memory");
#endif
}
582588
/**
 *  @brief Blocks until all committed wgmma-groups of the calling warpgroup
 *         have completed. Requires SM90+ (Hopper); compiles to a no-op on
 *         older architectures and on host.
 *
 *  `wgmma.wait_group.sync.aligned N;` waits until at most N groups are still
 *  pending, so the immediate `0` drains every outstanding group. The
 *  `"memory"` clobber is essential here: without it the compiler may hoist
 *  reads of the accumulator registers above the wait, observing stale values.
 */
__device__ void wgmma_sync_group() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("wgmma.wait_group.sync.aligned 0;" ::: "memory");
#endif
}
588594
@@ -595,9 +601,13 @@ __global__ void tops_f16f32_sm90wgmma_64x256x16_loop128_cuda_kernel() {
595601 float c_registers[128 ] = {0 .0f };
596602 std::uint64_t a_descriptor = wgmma_descriptor ((std::uint64_t )a_shared, 128 , 256 , 0 , 0 );
597603 std::uint64_t b_descriptor = wgmma_descriptor ((std::uint64_t )b_shared, 128 * 256 / 8 , 128 , 0 , 0 );
598- for (int i = 0 ; i != 128 ; ++i) wgmma_bf16f32_64x256x16 (c_registers, a_descriptor, b_descriptor);
599- wgmma_commit_group ();
600- wgmma_wait_group ();
604+ wgmma_fence ();
605+ for (int i = 0 ; i != 128 ; ++i) {
606+ wgmma_bf16f32_64x256x16 (c_registers, a_descriptor, b_descriptor);
607+ wgmma_commit_group ();
608+ }
609+ wgmma_sync_group ();
610+ if (threadIdx .x == 2147483647 ) *(std::uint16_t *)nullptr = c_registers[0 ];
601611}
602612
603613__global__ void tops_bf16f32_sm90wgmma_64x256x16_loop128_cuda_kernel () {
@@ -609,25 +619,33 @@ __global__ void tops_bf16f32_sm90wgmma_64x256x16_loop128_cuda_kernel() {
609619 float c_registers[128 ] = {0 .0f };
610620 std::uint64_t a_descriptor = wgmma_descriptor ((std::uint64_t )a_shared, 128 , 256 , 0 , 0 );
611621 std::uint64_t b_descriptor = wgmma_descriptor ((std::uint64_t )b_shared, 128 * 256 / 8 , 128 , 0 , 0 );
612- for (int i = 0 ; i != 128 ; ++i) wgmma_bf16f32_64x256x16 (c_registers, a_descriptor, b_descriptor);
613- wgmma_commit_group ();
614- wgmma_wait_group ();
622+ wgmma_fence ();
623+ for (int i = 0 ; i != 128 ; ++i) {
624+ wgmma_bf16f32_64x256x16 (c_registers, a_descriptor, b_descriptor);
625+ wgmma_commit_group ();
626+ }
627+ wgmma_sync_group ();
628+ if (threadIdx .x == 2147483647 ) *(std::uint16_t *)nullptr = c_registers[0 ];
615629}
616630
/**
 *  @brief Throughput micro-benchmark: issues 128 back-to-back
 *         `wgmma.mma_async.m64n256k8.f32.tf32.tf32` operations from shared
 *         memory and drains them, to estimate peak TF32 TOPS on Hopper.
 *
 *  Launch expectations: one full warpgroup (128 threads) per block, as
 *  required by `wgmma` — TODO confirm against the host-side launch code,
 *  which is not visible in this chunk.
 */
__global__ void tops_tf32f32_sm90wgmma_64x256x8_loop128_cuda_kernel() {
    // 64x256x8 is the largest tile size for `tf32` supported on Hopper.
    // Four-byte representations should be used for storage. Each entry will
    // shifted right by 13 bits before multiplication.
    // The shared tiles are deliberately left uninitialized: only timing
    // matters here, not the numeric result.
    __shared__ std::uint32_t a_shared[64][8];
    __shared__ std::uint32_t b_shared[256][8];

    // TODO: Unlike smaller 2-byte floats, the stride sizes will be different here.
    // NOTE(review): the leading-dimension/stride arguments below look copied
    // from the 2-byte (f16/bf16) variants — verify them against the Hopper
    // shared-memory descriptor layout before trusting the results.
    float c_registers[128] = {0.0f};
    std::uint64_t a_descriptor = wgmma_descriptor((std::uint64_t)a_shared, 128, 256, 0, 0);
    std::uint64_t b_descriptor = wgmma_descriptor((std::uint64_t)b_shared, 128 * 256 / 8, 128, 0, 0);
    // PTX requires a `wgmma.fence` between accumulator-register accesses
    // (the initialization above) and the first `wgmma.mma_async`.
    wgmma_fence();
    for (int i = 0; i != 128; ++i) {
        wgmma_tf32f32_64x256x8(c_registers, a_descriptor, b_descriptor);
        wgmma_commit_group();
    }
    wgmma_sync_group();
    // Unreachable dead store (threadIdx.x is bounded by the 1024-thread block
    // limit, far below INT_MAX): its only purpose is to make the accumulators
    // observable so the compiler cannot eliminate the benchmark loop.
    if (threadIdx.x == 2147483647) *(std::uint32_t *)nullptr = c_registers[0];
}
632650
633651#pragma endregion
0 commit comments