29 | 29 | TensorMemoryScalesLayout, |
30 | 30 | allocate_tensor_memory, |
31 | 31 | get_tmem_32x32b_reg_layout, |
| 32 | + get_tmem_scales_reg_layout, |
32 | 33 | tcgen05_mma, |
| 34 | + tcgen05_mma_scaled, |
33 | 35 | tcgen05_commit, |
34 | 36 | tcgen05_copy, |
35 | 37 | float2, |
@@ -1329,3 +1331,92 @@ def kernel_auto_layout_constant(threads_per_warp: ttgl.constexpr): |
1329 | 1331 |
1330 | 1332 | def test_auto_layout_constant(): |
1331 | 1333 | kernel_auto_layout_constant.warmup(THREADS_PER_WARP, grid=(1, )) |
| 1334 | + |
| 1335 | + |
| 1336 | +def fp8e8m0_to_float32(scale): |
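| | + # An e8m0 scale byte is just a biased exponent: shifting it into the fp32 exponent field (bit 23 and up) reconstructs 2**(x - 127).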
| 1337 | + scale = scale.view(torch.uint8) |
| 1338 | + scale = scale.to(torch.int32) |
| 1339 | + scale = scale << 23 |
| 1340 | + scale = scale.view(torch.float32) |
| 1341 | + return scale |
| 1342 | + |
| 1343 | + |
| 1344 | +@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell") |
| 1345 | +def test_tcgen05_mma_scaled_minimal(): |
| 1346 | + M = 128 |
| 1347 | + N = 128 |
| 1348 | + K = 128 |
| 1349 | + threads_per_warp = ttgl.constexpr(THREADS_PER_WARP) |
| 1350 | + |
| 1351 | + @gluon.jit |
| 1352 | + def kernel(out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, a, b, a_scale, b_scale): |
| 1353 | + # Simple register layout for creating constants and storing results |
| 1354 | + reg_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [threads_per_warp, 1], [ttgl.num_warps(), 1], [1, 0]) |
| 1355 | + |
| 1356 | + # Shared-memory layouts for MMA operands |
| 1357 | + nvmma_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, transposed=False, |
| 1358 | + element_bitwidth=8, rank=2) |
| 1359 | + # Load the A and B operand tiles from global memory and stage them in shared memory for the MMA
| 1360 | + block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], warps_per_cta=[ttgl.num_warps(), 1], |
| 1361 | + order=[1, 0]) |
| 1362 | + a_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, block_layout))[:, None] |
| 1363 | + a_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(0, block_layout))[None, :] |
| 1364 | + b_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(1, block_layout))[:, None] |
| 1365 | + b_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, block_layout))[None, :] |
| 1366 | + |
| 1367 | + a_tile = ttgl.load(a + a_offs_m * K + a_offs_k) |
| 1368 | + b_tile = ttgl.load(b + b_offs_k * N + b_offs_n) |
| 1369 | + a_smem = ttgl.allocate_shared_memory(ttgl.float8e5, [M, K], nvmma_layout, a_tile) |
| 1370 | + b_smem = ttgl.allocate_shared_memory(ttgl.float8e5, [K, N], nvmma_layout, b_tile) |
| 1371 | + |
| 1372 | + # Accumulator in TMEM initialized to zeros
| 1373 | + acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout([M, N], col_stride=1) |
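| | + # Register layout matching the TMEM accumulator, used both to build the initial value and to read the result back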
| 1374 | + tmem_reg_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(M, N, [M, N], ttgl.num_warps()) |
| 1375 | + acc_init = ttgl.zeros([M, N], ttgl.float32, layout=tmem_reg_layout) |
| 1376 | + acc_tmem = allocate_tensor_memory(ttgl.float32, [M, N], acc_tmem_layout, acc_init) |
| 1377 | + |
| 1378 | + # Load the e8m0 scales from global memory and stage them in TMEM
| 1379 | + scale_layout: ttgl.constexpr = TensorMemoryScalesLayout() |
| 1380 | + scale_reg_layout: ttgl.constexpr = get_tmem_scales_reg_layout(M, N, [M, N], ttgl.num_warps()) |
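| | + # One e8m0 scale covers 32 contiguous elements along K, so each scale operand has K // 32 columns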
| 1381 | + scale_offs_k = ttgl.arange(0, (K // 32), layout=ttgl.SliceLayout(0, scale_reg_layout))[None, :] |
| 1382 | + scale_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, scale_reg_layout))[:, None] |
| 1383 | + scale_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(1, scale_reg_layout))[:, None] |
| 1384 | + a_scale_init = ttgl.load(a_scale + scale_offs_m * (K // 32) + scale_offs_k) |
| 1385 | + b_scale_init = ttgl.load(b_scale + scale_offs_n * (K // 32) + scale_offs_k) |
| 1386 | + a_scale_tmem = allocate_tensor_memory(ttgl.int8, [M, K // 32], scale_layout, a_scale_init) |
| 1387 | + b_scale_tmem = allocate_tensor_memory(ttgl.int8, [N, K // 32], scale_layout, b_scale_init)
| 1388 | + |
| 1389 | + # Issue a single scaled MMA and commit |
| 1390 | + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) |
| 1391 | + mbarrier.init(bar, count=1) |
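| | + # The commit below ties completion of the scaled MMA to this mbarrier; waiting on phase 0 blocks until it arrives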
| 1392 | + tcgen05_mma_scaled(a_smem, b_smem, acc_tmem, a_scale_tmem, b_scale_tmem, "e5m2", "e5m2", use_acc=True) |
| 1393 | + tcgen05_commit(bar) |
| 1394 | + mbarrier.wait(bar, phase=0) |
| 1395 | + |
| 1396 | + # Load result from TMEM and store to global |
| 1397 | + out_reg = acc_tmem.load(tmem_reg_layout) |
| 1398 | + store_layout: ttgl.constexpr = reg_layout |
| 1399 | + offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, store_layout))[:, None] |
| 1400 | + offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, store_layout))[None, :] |
| 1401 | + offs = offs_m * N + offs_n |
| 1402 | + ttgl.store(out_ptr + offs, ttgl.convert_layout(out_reg, store_layout)) |
| 1403 | + |
| 1404 | + out = torch.empty((M, N), dtype=torch.float32, device="cuda") |
| 1405 | + a = torch.randint(20, 40, (M, K), dtype=torch.uint8, device="cuda").view(torch.float8_e5m2) |
| 1406 | + b = torch.randint(20, 40, (K, N), dtype=torch.uint8, device="cuda").view(torch.float8_e5m2) |
| 1407 | + a_scale = torch.randint(64, 130, (M, K // 32), dtype=torch.uint8, device="cuda") |
| 1408 | + b_scale = torch.randint(64, 130, (N, K // 32), dtype=torch.uint8, device="cuda") |
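| | + # Scale bytes in [64, 130) decode to 2**-63 .. 2**2 under the e8m0 interpretation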
| 1409 | + compiled = kernel[(1, )](out, M, N, K, a, b, a_scale, b_scale) |
| 1410 | + A = a.to(torch.float32) |
| 1411 | + B = b.to(torch.float32) |
| 1412 | + a_scale_f32 = fp8e8m0_to_float32(a_scale) |
| 1413 | + b_scale_f32 = fp8e8m0_to_float32(b_scale) |
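| | + # Broadcast each e8m0 scale across its 32-element block along K to build the fp32 reference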
| 1414 | + a_scale_f32 = a_scale_f32.repeat_interleave(32, dim=1) |
| 1415 | + b_scale_f32 = b_scale_f32.repeat_interleave(32, dim=1) |
| 1416 | + b_scale_f32 = b_scale_f32.T.contiguous() |
| 1417 | + A = A * a_scale_f32 |
| 1418 | + B = B * b_scale_f32 |
| 1419 | + ref = torch.matmul(A, B) |
| 1420 | + torch.testing.assert_close(out, ref, atol=1e-6, rtol=1e-6) |
| 1421 | + ttgir = compiled.asm["ttgir"] |
| 1422 | + assert "ttng.tc_gen5_mma_scaled" in ttgir |