Skip to content

Commit bb78fae

Browse files
authored
[Blackwell] Support narrower TMEM messages and shapes (#5945)
Narrower message widths can alleviate register pressure, avoiding spills in workloads that require a large number of per-thread registers. * The refactor separates TMEM-atom-derived message constants from workload-derived message constraints. * Narrowing occurs when a single message would require >= 50% of the available thread registers (128) and the workload requires all available registers (256) to complete. * Adds codegen support for the 16x256b shape of tcgen05.st/tcgen05.ld. With subsequent work, this can pair with downstream stmatrix ops for lower-latency epilogues.
1 parent bca378d commit bb78fae

File tree

2 files changed

+284
-125
lines changed

2 files changed

+284
-125
lines changed

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
341341
tt.return
342342
}
343343
}
344+
345+
// -----
346+
347+
348+
// Layout alias; not referenced by the function in this test section.
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
349+
// Register layout used by the test: sizePerThread = [1, 256] gives each thread
// one full 256-element row; 32 threads x 4 warps along dim 0 cover the 128 rows.
#blocked1 = #ttg.blocked<{sizePerThread = [1, 256], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
350+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 256, unpacked = true>
351+
352+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} {
353+
// Test: a 128x256xf32 tensor is written to tensor memory (tmem_alloc with an
// initializer) and read back (tmem_load). The directives below expect each
// direction to be emitted as four 32x32b.x64 messages followed by the matching
// tcgen05.wait. Despite the "_ld_" in the name, the store sequence is checked too.
// NOTE(review): four x64 messages presumably cover the 256 values per thread
// (4 * 64 = 256) -- confirm against the lowering.
// CHECK-LABEL: @tensor_memory_ld_128x256
354+
// CHECK: tcgen05.st.sync.aligned.32x32b.x64.b32
355+
// CHECK: tcgen05.st.sync.aligned.32x32b.x64.b32
356+
// CHECK: tcgen05.st.sync.aligned.32x32b.x64.b32
357+
// CHECK: tcgen05.st.sync.aligned.32x32b.x64.b32
358+
// CHECK: tcgen05.wait::st.sync.aligned
359+
// CHECK: tcgen05.ld.sync.aligned.32x32b.x64.b32
360+
// CHECK: tcgen05.ld.sync.aligned.32x32b.x64.b32
361+
// CHECK: tcgen05.ld.sync.aligned.32x32b.x64.b32
362+
// CHECK: tcgen05.ld.sync.aligned.32x32b.x64.b32
363+
// CHECK: tcgen05.wait::ld.sync.aligned
364+
tt.func public @tensor_memory_ld_128x256(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
365+
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked1>
366+
%0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked1>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
367+
%20 = ttng.tmem_load %0 : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf32, #blocked1>
368+
tt.return
369+
}
370+
}

0 commit comments

Comments
 (0)