
Commit 09f1aa4

[Gluon] Add helper and execution test for mma_scaled op (#8410)
Also add device_print
1 parent b611ccd commit 09f1aa4

File tree

5 files changed (+154 −5 lines)


python/test/gluon/test_core.py

Lines changed: 91 additions & 0 deletions
@@ -29,7 +29,9 @@
     TensorMemoryScalesLayout,
     allocate_tensor_memory,
     get_tmem_32x32b_reg_layout,
+    get_tmem_scales_reg_layout,
     tcgen05_mma,
+    tcgen05_mma_scaled,
     tcgen05_commit,
     tcgen05_copy,
     float2,
@@ -1329,3 +1331,92 @@ def kernel_auto_layout_constant(threads_per_warp: ttgl.constexpr):

 def test_auto_layout_constant():
     kernel_auto_layout_constant.warmup(THREADS_PER_WARP, grid=(1, ))
+
+
+def fp8e8m0_to_float32(scale):
+    scale = scale.view(torch.uint8)
+    scale = scale.to(torch.int32)
+    scale = scale << 23
+    scale = scale.view(torch.float32)
+    return scale
+
+
+@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+def test_tcgen05_mma_scaled_minimal():
+    M = 128
+    N = 128
+    K = 128
+    threads_per_warp = ttgl.constexpr(THREADS_PER_WARP)
+
+    @gluon.jit
+    def kernel(out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, a, b, a_scale, b_scale):
+        # Simple register layout for creating constants and storing results
+        reg_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [threads_per_warp, 1], [ttgl.num_warps(), 1], [1, 0])
+
+        # Shared-memory layouts for MMA operands
+        nvmma_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, transposed=False,
+                                                              element_bitwidth=8, rank=2)
+        # Load the operand tiles from global memory and stage them in shared memory
+        block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], warps_per_cta=[ttgl.num_warps(), 1],
+                                                          order=[1, 0])
+        a_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, block_layout))[:, None]
+        a_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(0, block_layout))[None, :]
+        b_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(1, block_layout))[:, None]
+        b_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, block_layout))[None, :]
+
+        a_tile = ttgl.load(a + a_offs_m * K + a_offs_k)
+        b_tile = ttgl.load(b + b_offs_k * N + b_offs_n)
+        a_smem = ttgl.allocate_shared_memory(ttgl.float8e5, [M, K], nvmma_layout, a_tile)
+        b_smem = ttgl.allocate_shared_memory(ttgl.float8e5, [K, N], nvmma_layout, b_tile)
+
+        # Accumulator in TMEM initialized to zeros
+        acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout([M, N], col_stride=1)
+        tmem_reg_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(M, N, [M, N], ttgl.num_warps())
+        acc_init = ttgl.zeros([M, N], ttgl.float32, layout=tmem_reg_layout)
+        acc_tmem = allocate_tensor_memory(ttgl.float32, [M, N], acc_tmem_layout, acc_init)
+
+        # Load the e8m0 scales and place them in TMEM
+        scale_layout: ttgl.constexpr = TensorMemoryScalesLayout()
+        scale_reg_layout: ttgl.constexpr = get_tmem_scales_reg_layout(M, N, [M, N], ttgl.num_warps())
+        scale_offs_k = ttgl.arange(0, (K // 32), layout=ttgl.SliceLayout(0, scale_reg_layout))[None, :]
+        scale_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, scale_reg_layout))[:, None]
+        scale_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(1, scale_reg_layout))[:, None]
+        a_scale_init = ttgl.load(a_scale + scale_offs_m * (K // 32) + scale_offs_k)
+        b_scale_init = ttgl.load(b_scale + scale_offs_n * (K // 32) + scale_offs_k)
+        a_scale_tmem = allocate_tensor_memory(ttgl.int8, [M, K // 32], scale_layout, a_scale_init)
+        b_scale_tmem = allocate_tensor_memory(ttgl.int8, [M, K // 32], scale_layout, b_scale_init)
+
+        # Issue a single scaled MMA and commit
+        bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+        mbarrier.init(bar, count=1)
+        tcgen05_mma_scaled(a_smem, b_smem, acc_tmem, a_scale_tmem, b_scale_tmem, "e5m2", "e5m2", use_acc=True)
+        tcgen05_commit(bar)
+        mbarrier.wait(bar, phase=0)
+
+        # Load result from TMEM and store to global
+        out_reg = acc_tmem.load(tmem_reg_layout)
+        store_layout: ttgl.constexpr = reg_layout
+        offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, store_layout))[:, None]
+        offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, store_layout))[None, :]
+        offs = offs_m * N + offs_n
+        ttgl.store(out_ptr + offs, ttgl.convert_layout(out_reg, store_layout))
+
+    out = torch.empty((M, N), dtype=torch.float32, device="cuda")
+    a = torch.randint(20, 40, (M, K), dtype=torch.uint8, device="cuda").view(torch.float8_e5m2)
+    b = torch.randint(20, 40, (K, N), dtype=torch.uint8, device="cuda").view(torch.float8_e5m2)
+    a_scale = torch.randint(64, 130, (M, K // 32), dtype=torch.uint8, device="cuda")
+    b_scale = torch.randint(64, 130, (N, K // 32), dtype=torch.uint8, device="cuda")
+    compiled = kernel[(1, )](out, M, N, K, a, b, a_scale, b_scale)
+    A = a.to(torch.float32)
+    B = b.to(torch.float32)
+    a_scale_f32 = fp8e8m0_to_float32(a_scale)
+    b_scale_f32 = fp8e8m0_to_float32(b_scale)
+    a_scale_f32 = a_scale_f32.repeat_interleave(32, dim=1)
+    b_scale_f32 = b_scale_f32.repeat_interleave(32, dim=1)
+    b_scale_f32 = b_scale_f32.T.contiguous()
+    A = A * a_scale_f32
+    B = B * b_scale_f32
+    ref = torch.matmul(A, B)
+    torch.testing.assert_close(out, ref, atol=1e-6, rtol=1e-6)
+    ttgir = compiled.asm["ttgir"]
+    assert "ttng.tc_gen5_mma_scaled" in ttgir

python/test/gluon/test_frontend.py

Lines changed: 6 additions & 4 deletions
@@ -77,25 +77,27 @@ def test_convert_layout(target):


 @gluon.jit
-def assume_kernel(arg: tl.int32):
+def simple_ops_kernel(arg: tl.int32):
     ttgl.assume(arg > 1)
+    ttgl.device_print("arg: ", arg)


 @pytest.mark.parametrize("target", ALL_TARGETS)
-def test_assume(target):
+def test_simple_ops(target):
     arg = 100
     mod = run_parser(
-        assume_kernel,
+        simple_ops_kernel,
         *make_args(arg),
         target=target,
     )
     expecttest.assert_expected_inline(
         anonymize_ir(mod.str_nodebug()), """\
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @assume_kernel(%arg0: i32) attributes {noinline = false} {
+  tt.func public @simple_ops_kernel(%arg0: i32) attributes {noinline = false} {
     %c1_i32 = arith.constant 1 : i32
     %0 = arith.cmpi sgt, %arg0, %c1_i32 : i32
     llvm.intr.assume %0 : i1
+    tt.print " arg: " {hex = false, isSigned = array<i32: 1>} : %arg0 : i32
     tt.return
   }
 }
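This test only checks the parsed IR; at runtime ttgl.device_print lowers to tt.print and behaves like tl.device_print, since it is the same builtin re-exported through Gluon (see the hunks below). A minimal launch sketch, assuming a CUDA device, with the kernel name and grid chosen for illustration:

import triton.language as tl
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl

@gluon.jit
def print_arg_kernel(arg: tl.int32):
    # Prefix string first, then the values to print, as with tl.device_print.
    ttgl.device_print("arg: ", arg)

# Launch a single program; the printed line appears on stdout when the kernel runs.
print_arg_kernel[(1, )](42)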

python/triton/experimental/gluon/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@
     bank_conflicts,
     convert_layout,
     device_assert,
+    device_print,
     dot_fma,
     expand_dims,
     full,

python/triton/experimental/gluon/language/_core.py

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ def wrapper(*args, **kwargs):
 atomic_xor = builtin(tl_core.atomic_xor)
 broadcast = builtin(tl_core.broadcast)
 device_assert = builtin(tl_core.device_assert)
+device_print = builtin(tl_core.device_print)
 expand_dims = builtin(tl_core.expand_dims)
 inline_asm_elementwise = builtin(tl_core.inline_asm_elementwise)
 join = builtin(tl_core.join)

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 55 additions & 1 deletion
@@ -5,7 +5,7 @@
 from triton.runtime.jit import constexpr_function
 from triton.experimental.gluon.language import _core as ttgl
 from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
-from triton.experimental.gluon.language._layouts import BlockedLayout, _get_shape_per_cta
+from triton.experimental.gluon.language._layouts import BlockedLayout, _get_shape_per_cta, DistributedLinearLayout
 from triton.experimental.gluon.language._semantic import _check

 from . import tma
@@ -22,6 +22,7 @@
     "async_copy",
     "fence_async_shared",
     "get_tmem_32x32b_reg_layout",
+    "get_tmem_scales_reg_layout",
     "mbarrier",
     "mma_v2",
     "tensor_memory_descriptor",
@@ -135,6 +136,59 @@ def get_tmem_32x32b_reg_layout(M, N, shape, num_warps, ctas_per_cga=None, cta_sp
     )


+@constexpr_function
+def get_tmem_scales_reg_layout(M, N, shape, num_warps, ctas_per_cga=None, cta_split_num=None, cta_order=None):
+    """Return a linear layout that is compatible with the TMEM scales layout.
+    """
+    assert len(shape) == 2, "expected a 2D tensor"
+    assert num_warps in [4, 8], "expected 4 or 8 warps"
+
+    # Use per-CTA shape to build the linear layout bases
+    shape_per_cta = _get_shape_per_cta(shape, cta_split_num)
+    M_cta, N_cta = shape_per_cta[0], shape_per_cta[1]
+
+    # Register bases: pack 4 scales together along N; if fewer than 4, replicate.
+    reg_bases = []
+    i = 1
+    while i < 4:
+        if i >= N_cta:
+            reg_bases.append([0, 0])
+        else:
+            reg_bases.append([0, i])
+        i <<= 1
+
+    # Lane bases: distribute 32 rows of M along a warp.
+    lane_bases = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]]
+
+    # Warp bases: replicate across warps within a warpgroup by default.
+    warp_bases = [[0, 0], [0, 0]]
+
+    # Extend register bases for larger M and N beyond the initial pack.
+    i = 32
+    while i < M_cta:
+        reg_bases.append([i, 0])
+        i <<= 1
+
+    i = 4
+    while i < N_cta:
+        reg_bases.append([0, i])
+        i <<= 1
+
+    # For 8 warps, distribute the last dimension on the second warpgroup.
+    if num_warps == 8:
+        warp_bases.append(reg_bases[-1])
+        reg_bases.pop()
+
+    # No explicit CTA mapping here; the register layout is per-CTA.
+    return DistributedLinearLayout(
+        reg_bases=reg_bases,
+        lane_bases=lane_bases,
+        warp_bases=warp_bases,
+        block_bases=[],
+        shape=shape_per_cta,
+    )
+
+
 class tensor_memory_descriptor_type(base_type):

     def __init__(self, element_ty, shape, layout, alloc_shape):
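As a sanity check on what this helper produces, here is a standalone re-derivation of the bases in plain Python (it mirrors the loops above for inspection only and is not the library function), evaluated for a [128, 4] per-CTA scales shape (M = 128, K // 32 = 4) with 4 warps:

def scales_reg_bases(M_cta, N_cta, num_warps):
    # Mirror of get_tmem_scales_reg_layout's basis construction.
    reg_bases = []
    i = 1
    while i < 4:  # pack up to 4 scales along N; replicate when the tensor is narrower
        reg_bases.append([0, 0] if i >= N_cta else [0, i])
        i <<= 1
    lane_bases = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]]  # 32 rows of M across a warp
    warp_bases = [[0, 0], [0, 0]]  # replicated across the warps of a warpgroup
    i = 32
    while i < M_cta:  # remaining rows of M go to extra registers
        reg_bases.append([i, 0])
        i <<= 1
    i = 4
    while i < N_cta:  # remaining columns of N go to extra registers
        reg_bases.append([0, i])
        i <<= 1
    if num_warps == 8:  # the second warpgroup takes over the last register basis
        warp_bases.append(reg_bases.pop())
    return reg_bases, lane_bases, warp_bases

regs, lanes, warps = scales_reg_bases(128, 4, 4)
print(regs)   # [[0, 1], [0, 2], [32, 0], [64, 0]]
print(lanes)  # [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]]
print(warps)  # [[0, 0], [0, 0]]

With these bases a single warp covers the full 128 x 4 tile (4 register bits times 5 lane bits = 512 elements), and the zero warp bases mean the four warps of the warpgroup hold replicated copies.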
