How to count GPU Tensor operations correctly 🤯

@ashvardanian ashvardanian released this 11 Feb 12:08

Measuring Tensor-Core throughput is tricky! Many families of matrix-multiplication instructions exist. Practically every Nvidia GPU generation brings new tiles, new numeric types, mixed-precision schemes, and "structured sparsity" modes. Together, those form some of the longest instructions in PTX IR. To make things worse, Tensor Core scheduling and the scale of collective execution differ across generations!

  • Before Volta and Tensor Cores, each GPU thread would execute its own scalar Fused-Multiply-Add — easy-peasy, as long as you know how to choose the optimal grid size for your GPU model.
  • On Volta, with the new mma.* instructions and wmma:: intrinsics, 8 threads would execute every tiled Mat-Mul together. Nvidia engineers creatively called this octet a "quadpair", of course 🤦‍♂️
  • On Ampere, with the new wmma.mma.* instructions, all 32 threads in a single "warp" work together. This abstraction makes sense to people familiar with CUDA C++ and how scheduling works on the GPU. Great!
  • On Hopper, things changed again, of course, with wgmma.mma_async.sync.*, which supports basic asynchronous primitives at the hardware level. Now, 128 threads across 4 consecutive "warps" form a "warp group".
  • On Blackwell, you would be wise to expect a new change, and it came with a broader set of functionality refactored into an all-new tcgen05.* namespace of instructions 🧠 🔫

This PR addresses that by explicitly marking the collaboration "scale" of each instruction family and counting TOPS accordingly.
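The accounting boils down to one division: a tile is issued once per collaborating group, not once per thread. A minimal host-side sketch of that idea (the function name and signature are mine, not the PR's):

```cpp
#include <cstddef>

// Hypothetical helper: given the number of threads collaborating on one MMA
// tile (1 for scalar FMA, 8 for a Volta "quadpair", 32 for an Ampere warp,
// 128 for a Hopper warp group), count the tensor ops a kernel performs.
// A single m x n x k multiply-accumulate costs 2*m*n*k scalar ops
// (one multiply + one add per term).
inline double tensor_ops(std::size_t total_threads, std::size_t collaboration_scale,
                         std::size_t m, std::size_t n, std::size_t k,
                         std::size_t repetitions) {
    // Each group of `collaboration_scale` threads issues ONE tile per iteration,
    // so dividing by the scale avoids over-counting the same tile 8/32/128 times.
    double tiles = static_cast<double>(total_threads) / collaboration_scale
                 * static_cast<double>(repetitions);
    return tiles * 2.0 * m * n * k;
}
```

For example, one warp (32 threads) doing a 16×16×16 tile on Volta counts 4 quadpair tiles per iteration, not 32.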


Almost equally tricky is making sure that the PTXAS assembler doesn't optimize out the relevant code blocks. In the past, one approach I'd use was to put an impossible condition at the end of a CUDA C++ kernel, like this:

template <typename input_type_, typename output_type_, int m_, int n_, int k_, int repetitions_ = 128>
__device__ inline void tops_tc_cuda_kernel() {
    using namespace nvcuda;
    wmma::fragment<wmma::matrix_a, m_, n_, k_, input_type_, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, m_, n_, k_, input_type_, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, m_, n_, k_, output_type_> c_frag;
    for (int i = 0; i != repetitions_; ++i) wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    // Impossible condition: `threadIdx.x` never reaches `INT_MAX`, but the
    // compiler can't prove the store is dead, so the loop above survives.
    if (threadIdx.x == 2147483647) wmma::store_matrix_sync(nullptr, c_frag, 16, wmma::mem_row_major);
}

This way, the compiler will see that I'm trying to export the accumulated value and will not remove our mma_sync call, even if the target address is a NULL pointer. Another approach I'd often use in PTX is to define dummy global variables and export a few values there:

.visible .global .align 4 .s32 dummy_sink_s32[32];
.visible .global .align 4 .f32 dummy_sink_f32[32];
.visible .entry tops_f16f32_sm90tc_m64n256k16_loop128_ptx_kernel() {
    ...
loop_exit:
    // Zero argument means - wait for all committed WGMMAs to complete.
    wgmma.wait_group.sync.aligned 0;

    // Use volatile stores to force the accumulator values to be written out.
    // This dummy write (to a global variable) makes the work observable and
    // prevents the multiplication pipeline from being optimized out.
    st.global.volatile.f32 [dummy_sink_f32],      accum0;
    st.global.volatile.f32 [dummy_sink_f32+4],    accum1;
    ret;
}
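The same sink idea translates to plain C++ for illustration (the function is hypothetical; `dummy_sink_f32` just mirrors the PTX above): a volatile global makes the result observable, so the optimizer can't delete the loop as dead code.

```cpp
// Volatile global "sink": every write to it is an observable side effect.
volatile float dummy_sink_f32 = 0.0f;

// Hypothetical benchmark body; the volatile store keeps the FMA loop alive
// even at -O3, the same way the PTX `st.global.volatile` stores do above.
inline float fma_loop(float a, float b, int repetitions) {
    float acc = 0.0f;
    for (int i = 0; i != repetitions; ++i) acc = a * b + acc; // the work we want to keep
    dummy_sink_f32 = acc; // observable write: the loop can't be optimized out
    return acc;
}
```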

But with WGMMA, the PTXAS tool will optimize away our multiplications if the shared-memory tile descriptors aren't valid, even in a benchmark that never reads the result. So this PR shows how to assemble valid descriptors 🤗
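For reference, such a descriptor is a packed 64-bit value. Here is a host-side sketch of the packing; the bit positions follow my reading of the PTX ISA's shared-memory matrix-descriptor layout, and the helper name is mine, not the PR's:

```cpp
#include <cstdint>

// Sketch of a Hopper WGMMA shared-memory matrix descriptor encoder.
// Byte quantities are encoded divided by 16 (the 16-byte base granularity);
// field positions are my reading of the PTX ISA spec, not code from this PR.
inline std::uint64_t make_smem_descriptor(std::uint32_t smem_addr,      // shared-memory byte address
                                          std::uint32_t lead_dim_bytes, // leading-dimension byte offset
                                          std::uint32_t stride_bytes,   // stride-dimension byte offset
                                          std::uint32_t swizzle_mode) { // 0=none, 1=128B, 2=64B, 3=32B
    std::uint64_t desc = 0;
    desc |= (std::uint64_t)((smem_addr      >> 4) & 0x3FFF);       // bits  0-13: start address
    desc |= (std::uint64_t)((lead_dim_bytes >> 4) & 0x3FFF) << 16; // bits 16-29: leading-dim offset
    desc |= (std::uint64_t)((stride_bytes   >> 4) & 0x3FFF) << 32; // bits 32-45: stride offset
    desc |= (std::uint64_t)(swizzle_mode & 0x3) << 62;             // bits 62-63: swizzle mode
    return desc;
}
```

A descriptor built this way points at real shared memory with consistent strides, which is exactly what PTXAS checks before it agrees to keep the WGMMA pipeline.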


This PR fixes those issues and adds more PTX kernels to highlight different aspects of GPGPU development 🤗

Minor

  • Add: f16f32 WMMA variant for Ampere (28e639e)
  • Add: f16f32 MMA variant for Volta (1359ca7)
  • Add: Inline-PTX in C++ for WGMMA (6e16165)
  • Add: WGMMA synchronization (0207843)
  • Add: Inlined PTX kernels in CUDA C++ (e2a1bfc)

Patch

  • Docs: New H200 stats (b5d4610)
  • Docs: Naming temporary compilation results (da36475)
  • Improve: Drop small WGMMA for conciseness (7f63ef2)
  • Fix: Invoke f16f32 in WGMMA (4423421)
  • Fix: tf32 perf and waiting on fences (ea4a3e0)
  • Fix: Counting TOPS across TC generations (85f78c3)
  • Make: Split Hopper and Ampere PTX (733cbac)
  • Make: Target SM 9.0a over SM 9.0 (726c1e1)