Commit 0173f75

[AMD] Add Tests for MXFP GEMM Gluon Kernel for GFX1250 (#8371)
This PR adds tests for the MXFP GEMM Gluon kernel on GFX1250: a compile test that checks the scaled WMMA instruction is emitted, and a runtime test that validates results against a PyTorch reference. It also relaxes the format assertions in wmma_scaled to accept e4m3 and e5m2 in addition to e2m1.
1 parent d5f3f23 commit 0173f75

File tree

2 files changed: 191 additions, 2 deletions

python/triton/experimental/gluon/language/amd/gfx1250/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -59,8 +59,8 @@ def wmma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None)
         "accumulator tensor's layout must be (16, 16, 128)"

     # TODO: Add more formats
-    assert a_format.value in {"e2m1"}, f"Unsupported lhs_format: {a_format.value}"
-    assert b_format.value in {"e2m1"}, f"Unsupported rhs_format: {b_format.value}"
+    assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}"
+    assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}"

     assert a_scale is not None and b_scale is not None, "Scales must not be None"
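With the relaxed assertions, wmma_scaled accepts fp8 (e4m3, e5m2) operands in addition to packed fp4 (e2m1). Below is a minimal, compile-only sketch of such a call site. It is not part of this commit: the layouts are copied from the new test kernel further down, the shapes mirror the (32, 32, 128) test configuration, and ttgl.full plus the ttgl.float8e4nv / ttgl.float8e5 / ttgl.uint8 element types are assumed to be available in the Gluon namespace.

# Hypothetical, compile-only sketch (not in this commit); assumes the same imports as the test file.
@gluon.jit
def wmma_scaled_fp8_sketch():
    A_BLOCKED: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [8, 4], [4, 1], [1, 0])
    B_BLOCKED: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [16, 2], [4, 1], [1, 0])
    SCALE_BLOCKED: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [8, 4], [4, 1], [1, 0])
    WMMA: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, transposed=True, warps_per_cta=[2, 2], instr_shape=[16, 16, 128])
    DOT_A: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=WMMA, k_width=16)
    DOT_B: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=WMMA, k_width=16)
    A_SCALE: ttgl.constexpr = ttgl.DistributedLinearLayout(
        reg_bases=[[0, 1], [0, 2]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]],
        warp_bases=[[0, 0], [16, 0]], block_bases=[], shape=[32, 4])
    B_SCALE: ttgl.constexpr = ttgl.DistributedLinearLayout(
        reg_bases=[[0, 1], [0, 2]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]],
        warp_bases=[[16, 0], [0, 0]], block_bases=[], shape=[32, 4])

    # fp8 operands are unpacked, so the K extent equals BLOCK_K (128 here); one e8m0
    # scale covers 32 K elements, hence the (32, 4) scale tiles (127 encodes 2^0).
    a = ttgl.convert_layout(ttgl.full((32, 128), 0.0, ttgl.float8e4nv, layout=A_BLOCKED), DOT_A)
    b = ttgl.convert_layout(ttgl.full((128, 32), 0.0, ttgl.float8e5, layout=B_BLOCKED), DOT_B)
    a_scale = ttgl.convert_layout(ttgl.full((32, 4), 127, ttgl.uint8, layout=SCALE_BLOCKED), A_SCALE)
    b_scale = ttgl.convert_layout(ttgl.full((32, 4), 127, ttgl.uint8, layout=SCALE_BLOCKED), B_SCALE)
    acc = ttgl.zeros((32, 32), dtype=ttgl.float32, layout=WMMA)

    # Previously only "e2m1" passed the format assert; "e4m3" / "e5m2" are now accepted too.
    acc = ttgl.amd.gfx1250.wmma_scaled(a, a_scale, "e4m3", b, b_scale, "e5m2", acc)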

third_party/amd/python/test/test_gluon_gfx1250.py

Lines changed: 189 additions & 0 deletions
@@ -422,3 +422,192 @@ def test_runtime_tensor_copy(BLOCK_M, BLOCK_N):

    b_triton = b_device.cpu()
    assert torch.equal(b_triton, a)


@gluon.jit
def mxgemm_kernel(a_ptr, b_ptr, c_ptr, a_scale, b_scale, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm,
                  stride_cn, stride_scale, DTYPE_A: ttgl.constexpr, DTYPE_B: ttgl.constexpr,
                  SCALE_BLOCK: ttgl.constexpr, BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr,
                  BLOCK_K: ttgl.constexpr, GROUP_SIZE_M: ttgl.constexpr):
    DIV_FACTOR_A: ttgl.constexpr = 2 if DTYPE_A == "e2m1" else 1
    DIV_FACTOR_B: ttgl.constexpr = 2 if DTYPE_B == "e2m1" else 1
    BLOCK_K_SCALE: ttgl.constexpr = BLOCK_K // SCALE_BLOCK
    BLOCK_K_PACKED_A: ttgl.constexpr = BLOCK_K // DIV_FACTOR_A
    BLOCK_K_PACKED_B: ttgl.constexpr = BLOCK_K // DIV_FACTOR_B

    BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [8, 4], [4, 1], [1, 0])
    A_BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [8, 4], [4, 1], [1, 0])
    B_BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [16, 2], [4, 1], [1, 0])

    WMMA_LAYOUT: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, transposed=True, warps_per_cta=[2, 2],
                                                         instr_shape=[16, 16, 128])
    WMMA_LAYOUT_PACKED: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, transposed=True, warps_per_cta=[2, 2],
                                                                instr_shape=[16, 16, 64])
    A_SCALE_LINEAR_LAYOUT: ttgl.constexpr = ttgl.DistributedLinearLayout(
        reg_bases=[[0, 1], [0, 2]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp_bases=[[0, 0], [16, 0]],
        block_bases=[], shape=[32, 4])
    B_SCALE_LINEAR_LAYOUT: ttgl.constexpr = ttgl.DistributedLinearLayout(
        reg_bases=[[0, 1], [0, 2]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp_bases=[[16, 0], [0, 0]],
        block_bases=[], shape=[32, 4])

    DOT_LAYOUT_A: ttgl.constexpr = ttgl.DotOperandLayout(
        operand_index=0, parent=WMMA_LAYOUT_PACKED if DTYPE_A == "e2m1" else WMMA_LAYOUT, k_width=16)
    DOT_LAYOUT_B: ttgl.constexpr = ttgl.DotOperandLayout(
        operand_index=1, parent=WMMA_LAYOUT_PACKED if DTYPE_B == "e2m1" else WMMA_LAYOUT, k_width=16)

    pid = ttgl.program_id(axis=0)
    num_pid_m = ttgl.cdiv(M, BLOCK_M)
    num_pid_n = ttgl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, A_BLOCKED_LAYOUT))) % M
    offs_ak = ttgl.arange(0, BLOCK_K_PACKED_A, layout=ttgl.SliceLayout(0, A_BLOCKED_LAYOUT))
    offs_bk = ttgl.arange(0, BLOCK_K_PACKED_B, layout=ttgl.SliceLayout(1, B_BLOCKED_LAYOUT))
    offs_bn = (pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, B_BLOCKED_LAYOUT))) % N

    offs_scale_am = (pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT))) % M
    offs_scale_ak = ttgl.arange(0, BLOCK_K_SCALE, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT))
    offs_scale_bn = (pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT))) % N
    offs_scale_bk = ttgl.arange(0, BLOCK_K_SCALE, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT))

    a_scale_ptr = a_scale + offs_scale_am[:, None] * stride_scale + offs_scale_ak[None, :]
    b_scale_ptr = b_scale + offs_scale_bn[:, None] * stride_scale + offs_scale_bk[None, :]
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = ttgl.zeros((BLOCK_M, BLOCK_N), dtype=ttgl.float32, layout=WMMA_LAYOUT)
    for k in range(0, ttgl.cdiv(K, BLOCK_K)):
        k_remaining_a = K - k * BLOCK_K_PACKED_A
        k_remaining_b = K - k * BLOCK_K_PACKED_B
        valid_k_a = offs_ak < k_remaining_a
        valid_k_b = offs_bk < k_remaining_b

        scale_a = ttgl.load(a_scale_ptr)
        scale_b = ttgl.load(b_scale_ptr)
        scale_a = ttgl.convert_layout(scale_a, A_SCALE_LINEAR_LAYOUT)
        scale_b = ttgl.convert_layout(scale_b, B_SCALE_LINEAR_LAYOUT)

        a = ttgl.load(a_ptrs, mask=valid_k_a[None, :], other=0.0)
        b = ttgl.load(b_ptrs, mask=valid_k_b[:, None], other=0.0)
        a = ttgl.convert_layout(a, DOT_LAYOUT_A)
        b = ttgl.convert_layout(b, DOT_LAYOUT_B)

        accumulator = ttgl.amd.gfx1250.wmma_scaled(a, scale_a, DTYPE_A, b, scale_b, DTYPE_B, accumulator)

        a_ptrs += BLOCK_K_PACKED_A * stride_ak
        b_ptrs += BLOCK_K_PACKED_B * stride_bk

        a_scale_ptr += BLOCK_K_SCALE
        b_scale_ptr += BLOCK_K_SCALE

    offs_cm = pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, WMMA_LAYOUT))
    offs_cn = pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, WMMA_LAYOUT))
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    ttgl.store(c_ptrs, accumulator, mask=c_mask)
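As an aside (not part of the diff), the packing arithmetic behind the kernel's derived constexpr values can be checked in plain Python for one configuration the tests use (BLOCK_K = 128, SCALE_BLOCK = 32), assuming an e2m1 (mxfp4) A operand and an fp8 B operand:

# Packing arithmetic only; mirrors the constexpr definitions at the top of mxgemm_kernel.
BLOCK_K, SCALE_BLOCK = 128, 32
DIV_FACTOR_A = 2   # e2m1: two fp4 values packed per uint8 along K
DIV_FACTOR_B = 1   # e4m3 / e5m2: one fp8 value per byte, unpacked
BLOCK_K_PACKED_A = BLOCK_K // DIV_FACTOR_A   # 64 bytes of A loaded per K tile
BLOCK_K_PACKED_B = BLOCK_K // DIV_FACTOR_B   # 128 bytes of B loaded per K tile
BLOCK_K_SCALE = BLOCK_K // SCALE_BLOCK       # 4 e8m0 scales per row per K tile
assert (BLOCK_K_PACKED_A, BLOCK_K_PACKED_B, BLOCK_K_SCALE) == (64, 128, 4)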
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 64), (32, 32, 128)])
516+
@pytest.mark.parametrize("DTYPE_A", ["float8_e5m2", "float8_e4m3", "float4"])
517+
@pytest.mark.parametrize("DTYPE_B", ["float8_e5m2", "float8_e4m3", "float4"])
518+
def test_compile_mxgemm(BLOCK_M, BLOCK_N, BLOCK_K, DTYPE_A, DTYPE_B):
519+
scale_block = 32
520+
521+
if BLOCK_K < 128:
522+
pytest.skip("NYI: don't support block shape smaller than instr shape")
523+
524+
triton_dtype_converter = {'float8_e5m2': "fp8e5", "float8_e4m3": "fp8e4nv", "float4": "u8"}
525+
dot_scaled_dtype_converter = {'float8_e5m2': "e5m2", "float8_e4m3": "e4m3", "float4": "e2m1"}
526+
527+
k = triton.compile(
528+
gluon._runtime.GluonASTSource(
529+
fn=mxgemm_kernel, signature={
530+
"a_ptr": f"*{triton_dtype_converter[DTYPE_A]}", "b_ptr": f"*{triton_dtype_converter[DTYPE_B]}", "c_ptr":
531+
"*fp32", "a_scale": "*u8", "b_scale": "*u8", "M": "i32", "N": "i32", "K": "i32", "stride_am": "i32",
532+
"stride_ak": "i32", "stride_bk": "i32", "stride_bn": "i32", "stride_cm": "i32", "stride_cn": "i32",
533+
"stride_scale": "i32", "DTYPE_A": "constexpr", "DTYPE_B": "constexpr", "SCALE_BLOCK": "constexpr",
534+
"BLOCK_M": "constexpr", "BLOCK_N": "constexpr", "BLOCK_K": "constexpr", "GROUP_SIZE_M": "constexpr"
535+
}, constexprs={
536+
"DTYPE_A": dot_scaled_dtype_converter[DTYPE_A], "DTYPE_B": dot_scaled_dtype_converter[DTYPE_B],
537+
"SCALE_BLOCK": scale_block, "BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K, "GROUP_SIZE_M":
538+
1
539+
}), target=GPUTarget("hip", 'gfx1250', 32))
540+
541+
amdgcn = k.asm["amdgcn"]
542+
pattern = "v_wmma_scale_f32_16x16x128_f8f6f4"
543+
assert re.search(pattern, amdgcn), f"Can't find instruction {pattern} in AMDGCN assembly"
544+
545+
546+
@pytest.mark.parametrize("M, N, K", [(32, 32, 128), (128, 128, 512), (1, 8192, 512)])
547+
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 128), (64, 64, 128)])
548+
@pytest.mark.parametrize("DTYPE_A", ["float8_e5m2", "float8_e4m3", "float4"])
549+
@pytest.mark.parametrize("DTYPE_B", ["float8_e5m2", "float8_e4m3", "float4"])
550+
def test_runtime_mxgemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, DTYPE_A, DTYPE_B):
551+
scale_block = 32
552+
553+
torch.manual_seed(0)
554+
555+
def torch_gemm_mxfp(a, b, a_scale, b_scale, scale_block, M, N, K):
556+
a_scale_f32 = a_scale.to(torch.float32).repeat_interleave(scale_block, dim=1)[:M, :K]
557+
b_scale_f32 = b_scale.to(torch.float32).repeat_interleave(scale_block, dim=1).T.contiguous()[:K, :N]
558+
559+
a_f32 = a.to(torch.float32)
560+
b_f32 = b.to(torch.float32)
561+
562+
return torch.matmul(a_f32 * a_scale_f32, b_f32 * b_scale_f32).to(torch.float32)
563+
564+
def init_data(dtype, d0: int, d1: int):
565+
if dtype == 'float4':
566+
return MXFP4Tensor(size=(d0, d1)).random()
567+
elif dtype == "float8_e5m2":
568+
return torch.randint(20, 40, (d0, d1), dtype=torch.uint8).view(torch.float8_e5m2)
569+
elif dtype == "float8_e4m3":
570+
return torch.randint(20, 40, (d0, d1), dtype=torch.uint8).view(torch.float8_e4m3fn)
571+
else:
572+
raise NotImplementedError(f"NYI: unsupported dtype: {dtype}")
573+
574+
a = init_data(DTYPE_A, M, K)
575+
b = init_data(DTYPE_B, K, N)
576+
a_size = (M, (K + scale_block - 1) // scale_block)
577+
b_size = (N, (K + scale_block - 1) // scale_block)
578+
a_scale = MXScaleTensor(size=a_size).random(low=1.0, high=32.0)
579+
b_scale = MXScaleTensor(size=b_size).random(low=1.0, high=32.0)
580+
581+
c_ref = torch_gemm_mxfp(a, b, a_scale, b_scale, scale_block, M, N, K)
582+
583+
a_scale = a_scale.data
584+
b_scale = b_scale.data
585+
586+
# mxfp4 input needs packed along the k dim, i.e., two mxfp4 are packed in one uint8
587+
if DTYPE_A in ['float4', 'float6_e2m3', 'float6_e3m2']:
588+
a = a.to_packed_tensor(dim=1)
589+
if DTYPE_B in ['float4', 'float6_e2m3', 'float6_e3m2']:
590+
b = b.to_packed_tensor(dim=0)
591+
592+
c_d = torch.zeros(M, N, dtype=torch.float32).cuda()
593+
a_d = a.data.contiguous().cuda()
594+
b_d = b.data.contiguous().cuda()
595+
a_scale_d = a_scale.cuda()
596+
b_scale_d = b_scale.cuda()
597+
598+
stride_am, stride_ak = a_d.stride(0), a_d.stride(1)
599+
stride_bk, stride_bn = b_d.stride(0), b_d.stride(1)
600+
stride_cm, stride_cn = c_d.stride(0), c_d.stride(1)
601+
stride_scale = a_scale_d.stride(0)
602+
603+
numBlocks = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
604+
grid = [numBlocks, 1, 1]
605+
group_size_m = 1
606+
607+
dtype_converter = {'float8_e5m2': "e5m2", "float8_e4m3": "e4m3", "float4": "e2m1"}
608+
609+
mxgemm_kernel[grid](a_d, b_d, c_d, a_scale_d, b_scale_d, M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
610+
stride_cm, stride_cn, stride_scale, dtype_converter[DTYPE_A], dtype_converter[DTYPE_B],
611+
scale_block, BLOCK_M, BLOCK_N, BLOCK_K, group_size_m, num_warps=4, num_ctas=1)
612+
613+
torch.testing.assert_close(c_d.cpu(), c_ref.cpu(), rtol=1e-5, atol=1e-8)
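
As a side note (not part of the diff), the torch reference treats each e8m0 scale as covering scale_block = 32 consecutive K elements. The small self-contained snippet below, with arbitrary sizes, illustrates the repeat_interleave broadcast that torch_gemm_mxfp relies on:

import torch

# Illustration of the scale broadcast used in torch_gemm_mxfp (sizes are arbitrary).
M, K, scale_block = 4, 64, 32
a = torch.randn(M, K)
a_scale = torch.randint(1, 5, (M, K // scale_block)).to(torch.float32)  # one scale per 32-wide K block
a_scale_f32 = a_scale.repeat_interleave(scale_block, dim=1)             # (M, K // 32) -> (M, K)
assert a_scale_f32.shape == (M, K)
a_scaled = a * a_scale_f32  # element-wise, as applied to the matmul inputs in the reference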
