
Commit cf9b4ea

[Backend] Allow splitting block_m=64 TMEM along N (#7589)
There are two restrictions. First, the subslice has to be at least 64x2; at the moment this is enforced only by the fact that the right layout cannot be constructed for anything smaller. Second, the split point along N must be even, otherwise the slice cuts into one of the 64x2 chunks. (A short illustrative sketch of these constraints follows below.)
1 parent 93367dc commit cf9b4ea
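To make the constraints concrete, here is a short illustrative check (plain Python, not the actual verifier logic; the chunking rule is taken from the commit description above):

    def can_subslice_block_m_64(split_n, size_n):
        # Illustrative only: block_m=64 TMEM is organised in 64x2 chunks along N,
        # so a subslice must start at an even N (a chunk boundary) and span at
        # least one full chunk.
        return split_n % 2 == 0 and size_n >= 2

    assert can_subslice_block_m_64(0, 64)       # even split point, well above 64x2
    assert can_subslice_block_m_64(32, 2)       # smallest legal subslice
    assert not can_subslice_block_m_64(33, 2)   # odd split point cuts into a 64x2 chunk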

File tree

11 files changed (+361 lines, -46 lines)


include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 4 additions & 0 deletions
@@ -34,6 +34,10 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc"

+namespace mlir::triton::nvidia_gpu::impl {
+LogicalResult verifyMMAv5Op(Operation *op);
+} // namespace mlir::triton::nvidia_gpu::impl
+
 #define GET_ATTRDEF_CLASSES
 #include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc"

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td

Lines changed: 4 additions & 0 deletions
@@ -54,5 +54,9 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
       "setIsAsync",
       (ins "bool":$isAsync)>,
   ];
+
+  let verify = [{
+    return ::mlir::triton::nvidia_gpu::impl::verifyMMAv5Op($_op);
+  }];
 }
 #endif // TRITON_NVIDIAGPU_OP_INTERFACES

lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp

Lines changed: 19 additions & 3 deletions
@@ -116,11 +116,11 @@ Attribute getTmemLoadStoreLayout32x32b(unsigned M, unsigned N,
   unsigned numWarpGroups = numWarps / 4;
   if (numBlocks == 1) {
     // Split along the N dimension
-    sizePerThread = {1, N / (numWarpGroups * 2)};
+    sizePerThread = {1, ceil<unsigned>(N, numWarpGroups * 2)};
     threadsPerWarp = {16, 2};
     warpsPerCTA = {4, numWarpGroups};
   } else {
-    sizePerThread = {1, N / 2};
+    sizePerThread = {1, ceil<unsigned>(N, 2)};
     threadsPerWarp = {16, 2};
     warpsPerCTA = {0, 0};
     // Distribute at most as many warp groups as there is blocks
@@ -138,7 +138,7 @@
     warpsPerCTA = {4 * numWarpGroups, 1};
   } else {
     // Split along N dimension
-    sizePerThread = {1, N / numWarpGroups};
+    sizePerThread = {1, ceil<unsigned>(N, numWarpGroups)};
     threadsPerWarp = {32, 1};
     warpsPerCTA = {4, numWarpGroups};
   }
@@ -223,6 +223,22 @@
   return areLayoutsEquivalent(tensorType.getShape(), layout, enc);
 }

+LogicalResult impl::verifyMMAv5Op(Operation *op) {
+  auto isInterleaved = [](MemDescType memdesc) {
+    auto enc = dyn_cast<TensorMemoryEncodingAttr>(memdesc.getEncoding());
+    return enc && getTmemAllocSizes(memdesc).numRows != 64 &&
+           enc.getBlockM() == 64;
+  };
+
+  auto itf = cast<MMAv5OpInterface>(op);
+  if (isInterleaved(itf.getA().getType()) &&
+      isInterleaved(itf.getAccumulator().getType())) {
+    return op->emitOpError(
+        "does not support blockM=64 with interleaved blocks in TMEM layout");
+  }
+  return success();
+}
+
 } // namespace nvidia_gpu
 } // namespace triton
 } // namespace mlir
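The switch from plain integer division to ceil in getTmemLoadStoreLayout32x32b matters once N can be smaller than the divisor, as it is for the 64x1 register layouts requested by the new subslice test. A quick arithmetic sketch of the difference (plain Python, not library code):

    def cdiv(a, b):
        return (a + b - 1) // b

    N, num_warp_groups = 1, 1
    assert N // (num_warp_groups * 2) == 0    # old floor division: zero-sized sizePerThread
    assert cdiv(N, num_warp_groups * 2) == 1  # ceil keeps at least one element per thread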

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 5 additions & 2 deletions
@@ -575,8 +575,11 @@ LogicalResult TMEMSubSliceOp::verify() {
       srcTy.getEncoding());
   if (!encoding)
     return emitOpError("The source must be a tensor memory buffer.");
-  if (encoding.getBlockM() != 128)
-    return emitOpError("The source must be a 128xN layout.");
+  if (!llvm::is_contained({64, 128}, encoding.getBlockM())) {
+    return emitOpError("The source tensor memory descriptor must have a 128xN "
+                       "or 64xN layout, got block_m=")
+           << encoding.getBlockM();
+  }
   auto dstTy = cast<triton::gpu::MemDescType>(getResult().getType());
   auto dstEncoding = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
       dstTy.getEncoding());

lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp

Lines changed: 4 additions & 0 deletions
@@ -187,6 +187,10 @@ static Operation *getAlloc(Value value) {
       value = reinterpOp.getSrc();
       continue;
     }
+    if (auto slice = value.getDefiningOp<TMEMSubSliceOp>()) {
+      value = slice.getSrc();
+      continue;
+    }
     auto arg = dyn_cast<BlockArgument>(value);
     if (!arg || !isa<triton::gpu::WarpSpecializePartitionsOp>(
                     arg.getOwner()->getParentOp()))

python/test/gluon/test_frontend.py

Lines changed: 0 additions & 9 deletions
@@ -98,15 +98,6 @@ def kernel(src_layout: ttgl.constexpr, dst_layout: ttgl.constexpr):
     assert "layout conversion from BlockedLayout(size_per_thread=[2]" in str(e.value.__cause__)
     assert "to AutoLayout() is not trivial" in str(e.value.__cause__)

-    with pytest.raises(CompilationError) as e:
-        src_layout: ttgl.constexpr = ttgl.AutoLayout()
-        dst_layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
-        kernel.warmup(src_layout, dst_layout, grid=(1, ))
-
-    assert "layout conversion from AutoLayout()" in str(e.value.__cause__)
-    assert "to BlockedLayout(size_per_thread=[2]" in str(e.value.__cause__)
-    assert "is not trivial" in str(e.value.__cause__)
-

 @gluon.jit
 def shared_memory_kernel(XBLOCK: ttgl.constexpr, YBLOCK: ttgl.constexpr, layout_a: ttgl.constexpr,

python/test/unit/blackwell/test_tmem.py

Lines changed: 183 additions & 3 deletions
@@ -5,11 +5,20 @@
 import triton
 from triton.backends.compiler import GPUTarget

+from triton.experimental import gluon
+from triton.experimental.gluon import language as ttgl
+from triton.experimental.gluon.language.nvidia.blackwell import (
+    TensorMemoryLayout,
+    allocate_tensor_memory,
+    get_tmem_32x32b_reg_layout,
+    mbarrier,
+    tcgen05_mma,
+    tcgen05_commit,
+)

-def test_tmem_copy_2d():
-    if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 10:
-        pytest.skip("Test requires Blackwell target.")

+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+def test_tmem_copy_2d():
     device = "cuda"

     smem_h = 256
@@ -89,3 +89,174 @@ def test_tmem_copy_2d():
     for i in range(4):
         # Copied values are duplicated across warps
         assert torch.equal(x[m * 32:(m + 1) * 32], z_tri[32 * i:32 * (i + 1), col_offset:(col_offset + 4)])
+
+
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+def test_tmem_subslice_block_m_64():
+
+    @gluon.jit
+    def kernel(s_ptr, out_ptr):
+        BLOCK_M: ttgl.constexpr = 64
+        N: ttgl.constexpr = 128
+        BLOCK_N: ttgl.constexpr = 64
+
+        tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, BLOCK_N), unpacked=True)
+        s_tmem = allocate_tensor_memory(ttgl.float32, (BLOCK_M, N), layout=tmem_layout)
+        o_tmem = allocate_tensor_memory(ttgl.float32, (BLOCK_M, N), layout=tmem_layout)
+
+        layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, BLOCK_N, (BLOCK_M, N), num_warps=4)
+
+        offsets = ttgl.arange(0, BLOCK_M)[:, None] * N + ttgl.arange(0, N)[None, :]
+        offsets = ttgl.convert_layout(offsets, layout)
+        s = ttgl.load(s_ptr + offsets)
+
+        s_tmem.store(s)
+        o_tmem.store(s)
+
+        p_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, BLOCK_N), unpacked=False)
+        p_tmem = s_tmem.slice(0, N // 2)._reinterpret(ttgl.float16, [BLOCK_M, N], p_tmem_layout)
+        p_tmem.store(ttgl.full((BLOCK_M, N), 0.0, dtype=ttgl.float16, layout=layout))
+
+        d1_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, 1), unpacked=True)
+        d1_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, 1, (BLOCK_M, 1), num_warps=4)
+
+        m_tmem = s_tmem.slice(N // 4, 1)._reinterpret(ttgl.float32, [BLOCK_M, 1], d1_tmem_layout)
+        m_tmem.store(ttgl.full((BLOCK_M, 1), 2.0, dtype=ttgl.float32, layout=d1_layout))
+        l_tmem = s_tmem.slice(N // 4 + 1, 1)._reinterpret(ttgl.float32, [BLOCK_M, 1], d1_tmem_layout)
+        l_tmem.store(ttgl.full((BLOCK_M, 1), 3.0, dtype=ttgl.float32, layout=d1_layout))
+        a_tmem = s_tmem.slice(N // 4 + 2, 1)._reinterpret(ttgl.float32, [BLOCK_M, 1], d1_tmem_layout)
+        a_tmem.store(ttgl.full((BLOCK_M, 1), 4.0, dtype=ttgl.float32, layout=d1_layout))
+
+        s = s_tmem.load(layout)
+
+        ttgl.store(out_ptr + offsets, s)
+
+    torch.manual_seed(0)
+    s = torch.randn((64, 128), dtype=torch.float32, device="cuda")
+
+    out_tri = torch.empty_like(s)
+    compiled = kernel[(1, )](s, out_tri)
+
+    ttgir = compiled.asm["ttgir"]
+    # Check that we have two 64x128xf32 allocations.
+    assert ttgir.count("ttng.tmem_alloc") == 2
+    assert ttgir.count("ttng.tmem_alloc : () -> !ttg.memdesc<64x128xf32") == 2
+
+    # Check that we allocated only 128 columns of TMEM.
+    llir = compiled.asm["llir"]
+    assert llir.count("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [$1], 128")
+
+    # Given TMEM[0:32] is the slice of TMEM for warpgroup 0, the expected layout
+    # of S is
+    #
+    #   TMEM[0:16]  = S[0:16, 0:64]
+    #   TMEM[16:32] = S[0:16, 64:128]
+    #
+    # When slicing S to obtain P, we expect it to overlap with the left half,
+    # i.e. S[0:16, 0:32] and S[0:16, 64:96].
+    out_ref = s
+    out_ref[:, 0:32] = 0.0
+    out_ref[:, 64:96] = 0.0
+
+    # Given S = [s0, s1, s2, s3], they are arranged like
+    #
+    #   TMEM[0:16]  = [s0, s1]
+    #   TMEM[16:32] = [s2, s3]
+    #
+    # Thus slicing S at N//4 will obtain an offset to the beginning of s1.
+    out_ref[:, 32] = 2.0
+    out_ref[:, 33] = 3.0
+    out_ref[:, 34] = 4.0
+
+    torch.testing.assert_close(out_ref, out_tri, atol=0, rtol=0)
+
+
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+def test_block_m_64_mma():
+
+    @gluon.jit
+    def kernel(a_ptr, b_ptr, c_ptr, d_ptr):
+        BLOCK_M: ttgl.constexpr = 64
+        N: ttgl.constexpr = 128
+        BLOCK_N: ttgl.constexpr = 64
+
+        a_offsets = ttgl.arange(0, BLOCK_M)[:, None] * N + ttgl.arange(0, N)[None, :]
+        b_offsets = ttgl.arange(0, N)[:, None] * N + ttgl.arange(0, N)[None, :]
+
+        a_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, BLOCK_N, (BLOCK_M, N), num_warps=4)
+        b_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
+        a_offsets = ttgl.convert_layout(a_offsets, a_layout)
+        b_offsets = ttgl.convert_layout(b_offsets, b_layout)
+
+        a = ttgl.load(a_ptr + a_offsets)
+        b = ttgl.load(b_ptr + b_offsets)
+        c = ttgl.load(c_ptr + a_offsets)
+
+        a_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, BLOCK_N), unpacked=False)
+        acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, BLOCK_N), unpacked=True)
+        al_tmem = allocate_tensor_memory(ttgl.float16, (BLOCK_M, N), layout=a_tmem_layout)
+        ar_tmem = allocate_tensor_memory(ttgl.float16, (BLOCK_M, N), layout=a_tmem_layout)
+        acc_tmem = allocate_tensor_memory(ttgl.float32, (BLOCK_M, N), layout=acc_tmem_layout)
+
+        a0, a1 = a.reshape((BLOCK_M, 2, N // 2)).permute(0, 2, 1).split()
+
+        al = ttgl.join(a0, a1).permute(0, 2, 1).reshape((BLOCK_M, N))
+        ar = ttgl.join(a1, a0).permute(0, 2, 1).reshape((BLOCK_M, N))
+
+        al_tmem.store(ttgl.convert_layout(al, a_layout, assert_trivial=True))
+        ar_tmem.store(ttgl.convert_layout(ar, a_layout, assert_trivial=True))
+
+        b_shared_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=32, element_bitwidth=16, rank=2)
+        b_shared = ttgl.allocate_shared_memory(ttgl.float16, [N, N], layout=b_shared_layout)
+        b_shared.store(b)
+
+        acc_tmem.store(c)
+
+        bar = ttgl.allocate_shared_memory(ttgl.int64, [1], ttgl.constexpr(mbarrier.MBarrierLayout()))
+        mbarrier.init(bar, count=1)
+
+        # This is a manually tiled MMA where the LHS is in TMEM with blockM=64,
+        # and we circumvent the limitation that the LHS and the accumulator need
+        # to share the same TMEM rows by storing the LHS twice.
+        #
+        #   TMEM      al  ar  c
+        #   [0, 16)   a0  a1  c0
+        #   [16, 32)  a1  a0  c1
+        #
+        #   d0 = a0 @ b00 + a1 @ b10 + c0
+        #   d1 = a0 @ b01 + a1 @ b11 + c1
+
+        N2: ttgl.constexpr = N // 2
+        c0 = acc_tmem.slice(0, N2)
+        c1 = acc_tmem.slice(N2, N2)
+
+        tcgen05_mma(al_tmem.slice(0, N2), b_shared.slice(0, N2, dim=0).slice(0, N2, dim=1), c0)
+        tcgen05_mma(ar_tmem.slice(0, N2), b_shared.slice(N2, N2, dim=0).slice(0, N2, dim=1), c0)
+        tcgen05_mma(ar_tmem.slice(N2, N2), b_shared.slice(0, N2, dim=0).slice(N2, N2, dim=1), c1)
+        tcgen05_mma(al_tmem.slice(N2, N2), b_shared.slice(N2, N2, dim=0).slice(N2, N2, dim=1), c1)
+
+        tcgen05_commit(bar)
+        mbarrier.wait(bar, 0)
+        mbarrier.invalidate(bar)
+
+        d = acc_tmem.load(a_layout)
+        ttgl.store(d_ptr + a_offsets, d)
+
+    torch.manual_seed(0)
+    a = torch.randn((64, 128), dtype=torch.float16, device="cuda")
+    b = torch.randn((128, 128), dtype=torch.float16, device="cuda")
+    c = torch.randn((64, 128), dtype=torch.float32, device="cuda")
+
+    d_tri = torch.empty_like(c)
+    compiled = kernel[(1, )](a, b, c, d_tri)
+
+    ttgir = compiled.asm["ttgir"]
+    assert ttgir.count("ttng.tmem_alloc") == 3
+    assert ttgir.count("ttng.tmem_alloc : () -> !ttg.memdesc<64x128xf32") == 1
+    assert ttgir.count("ttng.tmem_alloc : () -> !ttg.memdesc<64x128xf16") == 2
+
+    llir = compiled.asm["llir"]
+    assert llir.count("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [$1], 128")
+
+    d_ref = a @ b + c
+    torch.testing.assert_close(d_ref, d_tri, rtol=0.08, atol=0)
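The tiling that test_block_m_64_mma performs through TMEM can be restated in plain PyTorch as a sanity check of the decomposition itself; this only verifies the math in the comment above, not the kernel:

    import torch

    torch.manual_seed(0)
    A = torch.randn(64, 128)
    B = torch.randn(128, 128)
    C = torch.randn(64, 128)

    a0, a1 = A[:, :64], A[:, 64:]        # split A along K
    b00, b01 = B[:64, :64], B[:64, 64:]  # split B into 64x64 quadrants
    b10, b11 = B[64:, :64], B[64:, 64:]

    d0 = a0 @ b00 + a1 @ b10 + C[:, :64]  # left half of D
    d1 = a0 @ b01 + a1 @ b11 + C[:, 64:]  # right half of D

    torch.testing.assert_close(torch.cat([d0, d1], dim=1), A @ B + C)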

python/triton/experimental/gluon/language/_core.py

Lines changed: 3 additions & 1 deletion
@@ -7,7 +7,7 @@
 from triton._C.libtriton.gluon_ir import GluonOpBuilder
 from ._semantic import GluonSemantic

-from ._layouts import SharedLayout, DistributedLayout
+from ._layouts import SharedLayout, DistributedLayout, AutoLayout
 from triton._C.libtriton import ir
 import triton.language.core as tl_core
 from triton.language.core import (
@@ -383,6 +383,8 @@ def convert_layout(value, layout, assert_trivial=False, _semantic=None):
         tensor: The tensor with the new layout.
     """
     layout = _unwrap_if_constexpr(layout)
+    if isinstance(value.type.layout, AutoLayout):
+        return set_auto_layout(value, layout, _semantic=_semantic)
     return _semantic.convert_layout(value, layout, assert_trivial)

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -77,10 +77,10 @@ def get_tmem_32x32b_reg_layout(M, N, shape, num_warps, ctas_per_cga=None, cta_sp
     if M == 64:
         threads_per_warp = [16, 2]
         if num_blocks == 1:
-            size_per_thread = [1, N // (num_warp_groups * 2)]
+            size_per_thread = [1, triton.cdiv(N, num_warp_groups * 2)]
             warps_per_cta = [4, num_warp_groups]
         else:
-            size_per_thread = [1, N // 2]
+            size_per_thread = [1, triton.cdiv(N, 2)]
             warps_per_cta = [4 * min(blocks_per_tile[0], num_warp_groups)]
             warps_per_cta.append(triton.cdiv(num_warp_groups, warps_per_cta[0] // 4))
     else:
@@ -89,7 +89,7 @@ def get_tmem_32x32b_reg_layout(M, N, shape, num_warps, ctas_per_cga=None, cta_sp
             threads_per_warp = [32, 1]
             warps_per_cta = [4 * num_warp_groups, 1]
         else:
-            size_per_thread = [1, N // num_warp_groups]
+            size_per_thread = [1, triton.cdiv(N, num_warp_groups)]
             threads_per_warp = [32, 1]
             warps_per_cta = [4, num_warp_groups]
     return BlockedLayout(
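This is the Python-side mirror of the ceil change in getTmemLoadStoreLayout32x32b above. For instance, the 64x1 register layout requested by the new subslice test now comes out with a non-degenerate size_per_thread; the expected values below are a hedged sketch based on the diff, not something the test suite asserts:

    from triton.experimental.gluon.language.nvidia.blackwell import get_tmem_32x32b_reg_layout

    # With M=64, N=1, shape=(64, 1) and num_warps=4 (one warp group), the old
    # floor division would give size_per_thread = [1, 0]; triton.cdiv yields [1, 1].
    layout = get_tmem_32x32b_reg_layout(64, 1, (64, 1), num_warps=4)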
