
Commit 7c735d7

Merge pull request #34 from ashvardanian/wgmma-stats
Counting Tensor Ops Correctly 🤯

Measuring Tensor-Core throughput is tricky! Many families of matrix-multiplication instructions exist. Practically every Nvidia GPU generation brings new tiles, new numeric types, mixed-precision schemes, and "structured sparsity" models. All of those together form [some of the longest PTX IR instructions](https://ashvardanian.com/posts/longest-ptx-instruction/). To make things worse, Tensor Core scheduling and the scale of collective execution differ across generations!

- Before Volta and Tensor Cores, each GPU thread would execute its own scalar Fused-Multiply-Add — easy-peasy, as long as you know how to choose the optimal grid size for your GPU model.
- On Volta, with the new `mma.*` instructions and `wmma::` intrinsics, 8 threads would execute every tiled Mat-Mul together. This scale of collaboration was creatively called by Nvidia engineers ~~an octet~~ a "quadpair", of course 🤦‍♂️
- On Ampere, with the new `wmma.mma.*` instructions, all 32 threads in a single "warp" work together. This abstraction makes sense to people familiar with CUDA C++ and how scheduling works on the GPU. Great!
- On Hopper, things changed again, of course, with `wgmma.mma_async.sync.*`, which supports basic asynchronous primitives at the hardware level. It has 128 threads across 4 consecutive "warps" forming a "warp group".
- On Blackwell, you would be wise to expect another change, and it came with a broader set of functionality refactored into an all-new `tcgen05.*` namespace of instructions 🧠 🔫

This PR addresses that by explicitly marking the collaboration "scale" of each instruction family and counting TOPS accordingly; a rough sketch of that accounting is included after this message.

---

The other half of the problem is keeping the compiler from optimizing the benchmark away. In CUDA C++, I guard the output store behind an impossible condition:

```cuda
template <typename input_type_, typename output_type_, int m_, int n_, int k_, int repetitions_ = 128>
__device__ inline void tops_tc_cuda_kernel() {
    using namespace nvcuda;
    wmma::fragment<wmma::matrix_a, m_, n_, k_, input_type_, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, m_, n_, k_, input_type_, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, m_, n_, k_, output_type_> c_frag;
    for (int i = 0; i != repetitions_; ++i)
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    if (threadIdx.x == 2147483647)
        wmma::store_matrix_sync(nullptr, c_frag, 16, wmma::mem_row_major);
}
```

This way, the compiler will see that I'm trying to export the accumulated value and will not remove our `mma_sync` call, even if the target address is a NULL pointer. Another approach I'd often use in PTX is to define dummy global variables and export a few values there:

```ptx
.visible .global .align 4 .s32 dummy_sink_s32[32];
.visible .global .align 4 .f32 dummy_sink_f32[32];

.visible .entry tops_f16f32_sm90tc_m64n256k16_loop128_ptx_kernel()
{
    ...
loop_exit:
    // Zero argument means - wait for all committed WGMMAs to complete.
    wgmma.wait_group.sync.aligned 0;

    // Use volatile stores to force the accumulator values to be written out.
    // This dummy write (to a global variable) makes the work observable and
    // prevents the multiplication pipeline from being optimized out.
    st.global.volatile.f32 [dummy_sink_f32], accum0;
    st.global.volatile.f32 [dummy_sink_f32+4], accum1;
    ret;
}
```

But with WGMMA, the `ptxas` tool will optimize our multiplications away if the shared-memory tile descriptors aren't valid, even if it's just for a benchmark. So this PR shows how to assemble valid descriptors 🤗

---

This PR fixes those issues and adds more PTX kernels to highlight the different aspects of GPGPU development 🤗
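To make the per-family accounting concrete, here is a minimal sketch; it is not the PR's actual code, and names like `tensor_core_scale_t` and `tops_per_gpu` are hypothetical. The idea is to normalize the raw thread count by how many threads cooperate on a single MMA instruction:

```cuda
#include <cstddef>

// Hypothetical helper: threads that cooperate on one MMA instruction,
// per the instruction-family overview in the commit message above.
enum class tensor_core_scale_t : std::size_t {
    single_k = 1,      // Pre-Volta: one scalar FMA per thread
    quadpair_k = 8,    // Volta `mma.*`: 8 threads per tile ("quadpair")
    warp_k = 32,       // Ampere `wmma.mma.*`: a full 32-thread warp per tile
    warpgroup_k = 128, // Hopper `wgmma.mma_async.*`: a 128-thread warp group
};

// Convert a timed benchmark run into TOPS. Each cooperating group of
// `scale` threads issues `repetitions` MMAs over an m x n x k tile,
// and one multiply-accumulate counts as 2 scalar operations.
inline double tops_per_gpu(std::size_t total_threads, std::size_t repetitions,
                           std::size_t m, std::size_t n, std::size_t k,
                           tensor_core_scale_t scale, double seconds) {
    std::size_t threads_per_op = static_cast<std::size_t>(scale);
    std::size_t tiles_issued = total_threads / threads_per_op; // not per thread!
    double ops = 2.0 * m * n * k * repetitions * tiles_issued;
    return ops / seconds / 1e12;
}
```

The key detail is dividing the launched thread count by the collaboration scale, so a Hopper warp-group MMA is not counted 128 times.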
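On the descriptor point: below is a rough sketch of how a 64-bit WGMMA shared-memory matrix descriptor can be packed. The helper name is made up, and the bit-field layout follows my reading of the PTX ISA's matrix-descriptor format (addresses and strides encoded in 16-byte units in 14-bit fields, swizzle mode in the top bits); treat the exact positions as assumptions to verify against the docs and the PR's PTX, not as the PR's code.

```cuda
#include <cstdint>

// Sketch of a WGMMA shared-memory matrix descriptor, under the assumed layout:
//   bits  0-13: tile start address,       in 16-byte units
//   bits 16-29: leading-dimension offset, in 16-byte units
//   bits 32-45: stride-dimension offset,  in 16-byte units
//   bits 62-63: swizzle mode (assumed: 0 = none, 1 = 128B, 2 = 64B, 3 = 32B)
__device__ inline std::uint64_t make_wgmma_descriptor( //
    std::uint32_t smem_address,        // shared-memory address of the tile
    std::uint32_t leading_byte_offset, // byte stride along the leading dimension
    std::uint32_t stride_byte_offset,  // byte stride along the strided dimension
    std::uint32_t swizzle_mode) {
    auto encode = [](std::uint32_t bytes) -> std::uint64_t {
        return (bytes & 0x3FFFF) >> 4; // keep 14 significant bits, 16-byte granularity
    };
    std::uint64_t descriptor = 0;
    descriptor |= encode(smem_address);
    descriptor |= encode(leading_byte_offset) << 16;
    descriptor |= encode(stride_byte_offset) << 32;
    descriptor |= static_cast<std::uint64_t>(swizzle_mode & 0x3) << 62;
    return descriptor;
}
```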
2 parents: afcf491 + b5d4610 · commit 7c735d7

File tree

7 files changed: +1250 / -467 lines


.gitignore

Lines changed: 3 additions & 1 deletion
```diff
@@ -7,6 +7,8 @@ build_release/
 
 # Temporary binaries
 /tmp/
-less_slow_from_ptx.cubin
 less_slow_from_cu.cubin
 less_slow_from_cu.ptx
+less_slow_sm70_from_ptx.cubin
+less_slow_sm80_from_ptx.cubin
+less_slow_sm90a_from_ptx.cubin
```

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
```diff
@@ -334,7 +334,7 @@ endif()
 # https://cmake.org/cmake/help/latest/variable/CMAKE_LANG_COMPILER_ID.html
 if(CMAKE_CUDA_COMPILER_ID STREQUAL "NVIDIA" OR CMAKE_CUDA_COMPILER_ID STREQUAL "NVHPC")
     set_target_properties(less_slow PROPERTIES POSITION_INDEPENDENT_CODE ON)
-    set_target_properties(less_slow PROPERTIES CUDA_ARCHITECTURES "70;75;80;89;90")
+    set_target_properties(less_slow PROPERTIES CUDA_ARCHITECTURES "70;75;80;89;90a")
     target_compile_options(less_slow PRIVATE
         -Wfatal-errors # Stop on first error
         -fopenmp # OpenMP support, also requires linking
@@ -434,7 +434,7 @@ if(USE_NVIDIA_CCCL)
     # target_link_libraries(less_slow PRIVATE nvidia::cutlass::cutlass)
 
     # List the PTX files you want to copy
-    set(PTX_FILES less_slow_sm70.ptx less_slow_sm90a.ptx)
+    set(PTX_FILES less_slow_sm70.ptx less_slow_sm80.ptx less_slow_sm90a.ptx)
 
     # Loop over each PTX file and add a custom command to copy it
     foreach(PTX ${PTX_FILES})
```

less_slow.cpp

Lines changed: 191 additions & 79 deletions
Large diffs are not rendered by default.

less_slow.cu

Lines changed: 311 additions & 48 deletions
Large diffs are not rendered by default.

less_slow_sm70.ptx

Lines changed: 103 additions & 25 deletions
```diff
@@ -18,8 +18,16 @@
  * You can validate this file by asking the Nvidia PTX Assembler to compile it
  * to `.cubin` for some target architecture:
  *
- *   $ ptxas -o less_slow_from_ptx.cubin -arch=sm_70 less_slow_sm70.ptx
- *   $ cuobjdump -sass less_slow_from_ptx.cubin | grep -i mma
+ *   $ ptxas -o less_slow_sm70_from_ptx.cubin -arch=sm_70 less_slow_sm70.ptx
+ *   $ cuobjdump -sass less_slow_sm70_from_ptx.cubin | grep -i mma
+ *
+ * Assuming how aggressively NVCC unrolls loops and the number of kernels in
+ * this file, you may want to deduplicate them:
+ *
+ *   $ cuobjdump -sass less_slow_sm70_from_ptx.cubin | grep -i mma | \
+ *   $ sed -r 's/\/\*[^*]+\*\///g' | \
+ *   $ sed -r 's/^[[:space:]]+//; s/[[:space:]]+$//' | \
+ *   $ sort -u
  *
  * @section Register File
  *
@@ -37,18 +45,95 @@
 .target sm_70       // Target architecture (SM 7.0 - Volta GPUs)
 .address_size 64    // 64-bit addressing
 
-.visible .entry tops_f16f16_sm70tc_16x16x16_loop128_ptx_kernel()
+.visible .entry tops_f16f16_sm70mma_8x8x4_loop128_ptx_kernel()
 {
     // Accumulator registers used for both input and output of the MMA operation
     .reg .b32 accum_0, accum_1, accum_2, accum_3;
 
-    // Registers to hold packed 16-bit data for matrix a (8 registers)
-    .reg .b32 matrix_a_0, matrix_a_1, matrix_a_2, matrix_a_3,
-              matrix_a_4, matrix_a_5, matrix_a_6, matrix_a_7;
+    // Registers to hold packed pairs of 16-bit data for matrix a (2 registers)
+    .reg .b32 matrix_a_0, matrix_a_1;
 
-    // Registers to hold packed 16-bit data for matrix b (8 registers)
-    .reg .b32 matrix_b_0, matrix_b_1, matrix_b_2, matrix_b_3,
-              matrix_b_4, matrix_b_5, matrix_b_6, matrix_b_7;
+    // Registers to hold packed pairs of 16-bit data for matrix b (2 registers)
+    .reg .b32 matrix_b_0, matrix_b_1;
+
+    // General-purpose registers for loop control and constant values
+    .reg .b32 loop_counter, loop_limit, packed_const;
+
+    // Predicate register for conditional branching (loop exit)
+    .reg .pred exit_predicate;
+
+    // Set up loop counter and loop limit
+    mov.u32 loop_counter, 0;
+    mov.u32 loop_limit, 128;
+
+    // Zero-initialize the accumulator registers
+    mov.f32 accum_0, 0.0;
+    mov.f32 accum_1, 0.0;
+    mov.f32 accum_2, 0.0;
+    mov.f32 accum_3, 0.0;
+
+    // Initialize constant for packed matrix data (placeholder)
+    mov.b32 packed_const, 0x00010001;
+
+    // Initialize matrix a registers with the packed constant
+    mov.b32 matrix_a_0, packed_const;
+    mov.b32 matrix_a_1, packed_const;
+
+    // Initialize matrix b registers with the packed constant
+    mov.b32 matrix_b_0, packed_const;
+    mov.b32 matrix_b_1, packed_const;
+
+    // The main loop will repeat for 128 iterations
+loop_start:
+    setp.ge.u32 exit_predicate, loop_counter, loop_limit;
+    @exit_predicate bra loop_end;
+
+    mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16
+        { accum_0, accum_1, accum_2, accum_3 },
+        { matrix_a_0, matrix_a_1 },
+        { matrix_b_0, matrix_b_1 },
+        { accum_0, accum_1, accum_2, accum_3 };
+
+    // Increment the loop counter
+    add.u32 loop_counter, loop_counter, 1;
+
+    // Branch back to the beginning of the loop
+    bra loop_start;
+
+loop_end:
+    // If we simply exit, the computation will be optimized out!
+    // Instead, let's check for an impossible condition, like if the thread ID
+    // is equal to `UINT_MAX`, and if so - write accumulators to global memory
+    // NULL address.
+    .reg .u32 tid;
+    .reg .pred impossible_predicate;
+    mov.u32 tid, %tid.x; //? Special system registers start with `%`
+    setp.ne.u32 impossible_predicate, tid, 0xFFFFFFFF;
+    @impossible_predicate bra loop_exit;
+
+    // Write into memory:
+    .reg .u64 store_ptr;
+    mov.u64 store_ptr, 0;
+    st.global.f32 [store_ptr], accum_0;
+    st.global.f32 [store_ptr+4], accum_1;
+    st.global.f32 [store_ptr+8], accum_2;
+    st.global.f32 [store_ptr+12], accum_3;
+
+loop_exit:
+    ret;
+}
+
+.visible .entry tops_f16f32_sm70mma_8x8x4_loop128_ptx_kernel()
+{
+    // Accumulator registers used for both input and output of the MMA operation
+    .reg .b32 accum_0, accum_1, accum_2, accum_3,
+              accum_4, accum_5, accum_6, accum_7;
+
+    // Registers to hold packed 16-bit data for matrix a (4 registers)
+    .reg .b32 matrix_a_0, matrix_a_1, matrix_a_2, matrix_a_3;
+
+    // Registers to hold packed 16-bit data for matrix b (4 registers)
+    .reg .b32 matrix_b_0, matrix_b_1, matrix_b_2, matrix_b_3;
 
     // General-purpose registers for loop control and constant values
     .reg .b32 loop_counter, loop_limit, packed_const;
@@ -74,33 +159,25 @@
     mov.b32 matrix_a_1, packed_const;
     mov.b32 matrix_a_2, packed_const;
     mov.b32 matrix_a_3, packed_const;
-    mov.b32 matrix_a_4, packed_const;
-    mov.b32 matrix_a_5, packed_const;
-    mov.b32 matrix_a_6, packed_const;
-    mov.b32 matrix_a_7, packed_const;
 
     // Initialize matrix b registers with the packed constant
     mov.b32 matrix_b_0, packed_const;
     mov.b32 matrix_b_1, packed_const;
     mov.b32 matrix_b_2, packed_const;
     mov.b32 matrix_b_3, packed_const;
-    mov.b32 matrix_b_4, packed_const;
-    mov.b32 matrix_b_5, packed_const;
-    mov.b32 matrix_b_6, packed_const;
-    mov.b32 matrix_b_7, packed_const;
 
     // The main loop will repeat for 128 iterations
 loop_start:
     setp.ge.u32 exit_predicate, loop_counter, loop_limit;
     @exit_predicate bra loop_end;
 
-    wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16
-        { accum_0, accum_1, accum_2, accum_3 },
-        { matrix_a_0, matrix_a_1, matrix_a_2, matrix_a_3,
-          matrix_a_4, matrix_a_5, matrix_a_6, matrix_a_7 },
-        { matrix_b_0, matrix_b_1, matrix_b_2, matrix_b_3,
-          matrix_b_4, matrix_b_5, matrix_b_6, matrix_b_7 },
-        { accum_0, accum_1, accum_2, accum_3 };
+    mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32
+        { accum_0, accum_1, accum_2, accum_3,
+          accum_4, accum_5, accum_6, accum_7 },
+        { matrix_a_0, matrix_a_1 },
+        { matrix_b_0, matrix_b_1 },
+        { accum_0, accum_1, accum_2, accum_3,
+          accum_4, accum_5, accum_6, accum_7 };
 
     // Increment the loop counter
     add.u32 loop_counter, loop_counter, 1;
@@ -147,7 +224,8 @@ loop_exit:
  * with both arguments in shared memory!
  *
  * Because only one `.version` directive can be placed in each file, for newer
- * kernels, go to `less_slow_sm90a.ptx`.
+ * kernels, go to `less_slow_sm80.ptx` for Ampere and `less_slow_sm90a.ptx`
+ * for Hopper.
  *
  * @see PTX module-level directives:
  * https://docs.nvidia.com/cuda/parallel-thread-execution/#ptx-module-directives
```
