
Commit ea4a3e0

Fix: tf32 perf and waiting on fences
1 parent 85f78c3 commit ea4a3e0

2 files changed: +39 −21 lines

less_slow.cpp

Lines changed: 4 additions & 4 deletions
@@ -2186,7 +2186,7 @@ BENCHMARK_CAPTURE(
 extern __global__ void tops_f16f32_sm90wgmma_64x256x16_loop128_cuda_kernel();
 extern __global__ void tops_bf16f32_sm90wgmma_64x256x16_loop128_cuda_kernel();
-extern __global__ void tops_tf32f32_sm90wgmma_64x256x16_loop128_cuda_kernel();
+extern __global__ void tops_tf32f32_sm90wgmma_64x256x8_loop128_cuda_kernel();

 BENCHMARK_CAPTURE(                                                                                //
     theoretic_tops_cuda, f16f32_sm90wgmma, tops_f16f32_sm90wgmma_64x256x16_loop128_cuda_kernel,   //
@@ -2196,9 +2196,9 @@ BENCHMARK_CAPTURE(
     theoretic_tops_cuda, bf16f32_sm90wgmma, tops_bf16f32_sm90wgmma_64x256x16_loop128_cuda_kernel, //
     64, 256, 16, 90, 128, tensor_core_scale_t::warpgroup_k)
     ->MinTime(10);
-BENCHMARK_CAPTURE(                                                                                //
-    theoretic_tops_cuda, tf32f32_sm90wgmma, tops_tf32f32_sm90wgmma_64x256x16_loop128_cuda_kernel, //
-    64, 256, 16, 90, 128, tensor_core_scale_t::warpgroup_k)
+BENCHMARK_CAPTURE(                                                                                //
+    theoretic_tops_cuda, tf32f32_sm90wgmma, tops_tf32f32_sm90wgmma_64x256x8_loop128_cuda_kernel,  //
+    64, 256, 8, 90, 128, tensor_core_scale_t::warpgroup_k)
     ->MinTime(10);

 #include <filesystem> // `std::filesystem::absolute` to locate PTX IR file
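A note on what the shape change does to the reported numbers: a single WGMMA on an m×n×k tile performs 2·m·n·k floating-point operations (one multiply and one add per accumulated element), so moving `tf32` from the invalid k = 16 shape down to the architecturally supported k = 8 halves the work per instruction, and the benchmark's FLOP accounting has to shrink with it. A quick standalone sanity check of that arithmetic; the helper below is illustrative only and not part of less_slow.cpp:

#include <cstdint>
#include <cstdio>

// Illustrative only: FLOPs issued by one WGMMA loop, counting one multiply
// and one add per element of the m x n x k tile, `loops` instructions total.
static std::uint64_t wgmma_loop_flops(std::uint64_t m, std::uint64_t n, std::uint64_t k,
                                      std::uint64_t loops) {
    return 2ull * m * n * k * loops;
}

int main() {
    std::printf("bf16 64x256x16 x128: %llu FLOPs\n",
                (unsigned long long)wgmma_loop_flops(64, 256, 16, 128));
    std::printf("tf32 64x256x8  x128: %llu FLOPs\n",
                (unsigned long long)wgmma_loop_flops(64, 256, 8, 128));
    return 0; // the tf32 loop now does half the work, so its TOPS estimate must scale too
}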

less_slow.cu

Lines changed: 35 additions & 17 deletions
@@ -527,12 +527,12 @@ __device__ void wgmma_bf16f32_64x256x16(float r[128], std::uint64_t a_descriptor
 #endif
 }

-__device__ void wgmma_tf32f32_64x256x16(float r[128], std::uint64_t a_descriptor, std::uint64_t b_descriptor) {
+__device__ void wgmma_tf32f32_64x256x8(float r[128], std::uint64_t a_descriptor, std::uint64_t b_descriptor) {
     //! Unlike the `f16` and `bf16` instructions, the `tf32` has fewer operands,
     //! and can't transpose the input matrices!
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
     asm volatile( //
-        "wgmma.mma_async.sync.aligned.m64n256k16.f32.tf32.tf32 "
+        "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 "
         "{"
         "%0, %1, %2, %3, %4, %5, %6, %7, "
         "%8, %9, %10, %11, %12, %13, %14, %15, "
@@ -574,15 +574,21 @@ __device__ void wgmma_tf32f32_64x256x16(float r[128], std::uint64_t a_descriptor
 #endif
 }

+__device__ void wgmma_fence() {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+    asm volatile("wgmma.fence.sync.aligned;");
+#endif
+}
+
 __device__ void wgmma_commit_group() {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
     asm volatile("wgmma.commit_group.sync.aligned;");
 #endif
 }

-__device__ void wgmma_wait_group() {
+__device__ void wgmma_sync_group() {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-    asm volatile("wgmma.wait_group.sync.aligned 1;");
+    asm volatile("wgmma.wait_group.sync.aligned 0;");
 #endif
 }
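For context on the new helper: the PTX-level protocol for Hopper warpgroup MMAs is fence, then async MMAs, then commit, then wait. `wgmma.fence` orders prior accesses to the accumulator registers and shared-memory tiles before the first `wgmma.mma_async`; `wgmma.commit_group` closes the batch of async MMAs issued so far into one group; and `wgmma.wait_group N` blocks until at most N groups remain pending, so an operand of 0 drains everything. A minimal sketch of the full sequence, assuming it sits in less_slow.cu next to the helpers and descriptor layout above:

// Sketch only: the minimal legal fence -> mma_async -> commit -> wait sequence,
// reusing this diff's helpers; not a kernel from the repo.
__global__ void wgmma_sequence_sketch_kernel() {
    __shared__ std::uint32_t a_shared[64][8];
    __shared__ std::uint32_t b_shared[256][8];
    float c_registers[128] = {0.0f};
    std::uint64_t a_descriptor = wgmma_descriptor((std::uint64_t)a_shared, 128, 256, 0, 0);
    std::uint64_t b_descriptor = wgmma_descriptor((std::uint64_t)b_shared, 128 * 256 / 8, 128, 0, 0);
    wgmma_fence();                                                    // order registers & SMEM first
    wgmma_tf32f32_64x256x8(c_registers, a_descriptor, b_descriptor);  // issue one async MMA
    wgmma_commit_group();                                             // close the group
    wgmma_sync_group();                                               // wait_group 0: drain all groups
    if (threadIdx.x == 2147483647) *(std::uint32_t *)nullptr = c_registers[0]; // keep results alive
}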

@@ -595,9 +601,13 @@ __global__ void tops_f16f32_sm90wgmma_64x256x16_loop128_cuda_kernel() {
     float c_registers[128] = {0.0f};
     std::uint64_t a_descriptor = wgmma_descriptor((std::uint64_t)a_shared, 128, 256, 0, 0);
     std::uint64_t b_descriptor = wgmma_descriptor((std::uint64_t)b_shared, 128 * 256 / 8, 128, 0, 0);
-    for (int i = 0; i != 128; ++i) wgmma_bf16f32_64x256x16(c_registers, a_descriptor, b_descriptor);
-    wgmma_commit_group();
-    wgmma_wait_group();
+    wgmma_fence();
+    for (int i = 0; i != 128; ++i) {
+        wgmma_bf16f32_64x256x16(c_registers, a_descriptor, b_descriptor);
+        wgmma_commit_group();
+    }
+    wgmma_sync_group();
+    if (threadIdx.x == 2147483647) *(std::uint16_t *)nullptr = c_registers[0];
 }

 __global__ void tops_bf16f32_sm90wgmma_64x256x16_loop128_cuda_kernel() {
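This restructuring is the "waiting on fences" half of the commit title. The old code issued all 128 async MMAs, committed them as a single group, and then waited with an operand of 1; since `wgmma.wait_group N` only blocks while more than N groups are pending, waiting on 1 with exactly one group in flight returns immediately, so the benchmark could stop timing before the matrix math had drained. Committing once per iteration and waiting on 0 closes that hole. A compact illustration of the operand's semantics, not repo code, assuming the same helpers as above:

// Illustration only: how the `wait_group` operand changes behavior.
__device__ void wait_group_semantics_sketch(float c_registers[128], std::uint64_t a_descriptor,
                                            std::uint64_t b_descriptor) {
    wgmma_fence();
    wgmma_tf32f32_64x256x8(c_registers, a_descriptor, b_descriptor);
    wgmma_commit_group(); // exactly one group is now pending
    // "wgmma.wait_group.sync.aligned 1;" may return at once: 1 pending <= 1.
    // "wgmma.wait_group.sync.aligned 0;" blocks until that group completes:
    wgmma_sync_group();
}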
@@ -609,25 +619,33 @@ __global__ void tops_bf16f32_sm90wgmma_64x256x16_loop128_cuda_kernel() {
     float c_registers[128] = {0.0f};
     std::uint64_t a_descriptor = wgmma_descriptor((std::uint64_t)a_shared, 128, 256, 0, 0);
     std::uint64_t b_descriptor = wgmma_descriptor((std::uint64_t)b_shared, 128 * 256 / 8, 128, 0, 0);
-    for (int i = 0; i != 128; ++i) wgmma_bf16f32_64x256x16(c_registers, a_descriptor, b_descriptor);
-    wgmma_commit_group();
-    wgmma_wait_group();
+    wgmma_fence();
+    for (int i = 0; i != 128; ++i) {
+        wgmma_bf16f32_64x256x16(c_registers, a_descriptor, b_descriptor);
+        wgmma_commit_group();
+    }
+    wgmma_sync_group();
+    if (threadIdx.x == 2147483647) *(std::uint16_t *)nullptr = c_registers[0];
 }

-__global__ void tops_tf32f32_sm90wgmma_64x256x16_loop128_cuda_kernel() {
-    // 64x256x16 is the largest tile size for `tf32` supported on Hopper.
+__global__ void tops_tf32f32_sm90wgmma_64x256x8_loop128_cuda_kernel() {
+    // 64x256x8 is the largest tile size for `tf32` supported on Hopper.
     // Four-byte representations should be used for storage. Each entry will
     // be shifted right by 13 bits before multiplication.
-    __shared__ std::uint32_t a_shared[64][16];
-    __shared__ std::uint32_t b_shared[256][16];
+    __shared__ std::uint32_t a_shared[64][8];
+    __shared__ std::uint32_t b_shared[256][8];

     // TODO: Unlike smaller 2-byte floats, the stride sizes will be different here.
     float c_registers[128] = {0.0f};
     std::uint64_t a_descriptor = wgmma_descriptor((std::uint64_t)a_shared, 128, 256, 0, 0);
     std::uint64_t b_descriptor = wgmma_descriptor((std::uint64_t)b_shared, 128 * 256 / 8, 128, 0, 0);
-    for (int i = 0; i != 128; ++i) wgmma_bf16f32_64x256x16(c_registers, a_descriptor, b_descriptor);
-    wgmma_commit_group();
-    wgmma_wait_group();
+    wgmma_fence();
+    for (int i = 0; i != 128; ++i) {
+        wgmma_tf32f32_64x256x8(c_registers, a_descriptor, b_descriptor);
+        wgmma_commit_group();
+    }
+    wgmma_sync_group();
+    if (threadIdx.x == 2147483647) *(std::uint32_t *)nullptr = c_registers[0];
 }

 #pragma endregion
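The `if (threadIdx.x == 2147483647)` stores appended to each kernel are a keep-alive idiom rather than real output: `threadIdx.x` is bounded by the block size, at most 1024 on current GPUs, so the branch is never taken, but the compiler cannot prove the accumulators are dead and therefore must keep the 128-iteration MMA loop it would otherwise eliminate. A standalone sketch of the same trick:

// Standalone illustration of the keep-alive idiom used in the kernels above:
// the guarded store is unreachable at runtime (threadIdx.x < 1024), yet it
// makes `accumulator` observable, so dead-code elimination keeps the loop.
__global__ void keep_alive_sketch_kernel() {
    float accumulator = 0.0f;
    for (int i = 0; i != 128; ++i) accumulator += static_cast<float>(i); // the work being timed
    if (threadIdx.x == 2147483647) *(float *)nullptr = accumulator;
}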
