
Commit 22b1a44

[AMD] Support Scale Preshuffling in Decomposed Scaled Dot (#8170)
This PR adds support for scale preshuffling in the decomposed scaled dot path.
1 parent 3634051 commit 22b1a44
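
For orientation (not part of this commit): in a decomposed scaled dot, tl.dot_scaled is emulated by applying the e8m0 scales to the operands around a regular dot instead of using a native scaled-MFMA instruction, and scale preshuffling rearranges the scale tensors in memory so the kernel can load them efficiently. Below is a minimal, hedged sketch of the mixed-operand usage the new test exercises (fp16 x mxfp8e5, passing None for the operand without scales); the kernel name, single-tile launch, and sizes are illustrative assumptions, not code from this PR.

import torch
import triton
import triton.language as tl


@triton.jit
def scaled_dot_demo(a_ptr, b_ptr, b_scales_ptr, c_ptr,  # illustrative kernel, not from this PR
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # A is plain fp16 (M x K); B is fp8e5 (K x N) with one e8m0 scale per 32 K elements, laid out N x K//32.
    a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    b_scales = tl.load(b_scales_ptr + offs_n[:, None] * (BLOCK_K // 32) + tl.arange(0, BLOCK_K // 32)[None, :])
    # The operand without microscaling passes None for its scale, as the updated kernel in this diff does.
    acc = tl.dot_scaled(a, None, "fp16", b, b_scales, "e5m2", fast_math=False)
    tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)


M = N = K = 128
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda").to(torch.float8_e5m2)
b_scales = torch.full((N, K // 32), 127, dtype=torch.uint8, device="cuda")  # e8m0 127 encodes 2**0 == 1.0
c = torch.empty(M, N, device="cuda", dtype=torch.float32)
scaled_dot_demo[(1, )](a, b, b_scales, c, BLOCK_M=M, BLOCK_N=N, BLOCK_K=K)
print((c - a.float() @ b.float()).abs().max())  # small; any difference comes from the fp8/emulation rounding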

File tree

2 files changed (+255, -165 lines)

python/test/unit/language/test_matmul.py

Lines changed: 145 additions & 93 deletions
@@ -421,9 +421,9 @@ def block_scale_mxfp_matmul( #
                             stride_cm, stride_cn, #
                             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
                             NUM_STAGES: tl.constexpr, USE_2D_SCALE_LOAD: tl.constexpr):
-    ## This kernel assumes a_scale and b_scale are coming in with shapes
-    ## [BLOCK_M(or N) // 128, BLOCK_K // 128, 32, 4, 4] for optimial performance
-    ## on nvidia sm100+ HW
+    # This kernel assumes a_scale and b_scale are coming in with shapes
+    # [BLOCK_M(or N) // 128, BLOCK_K // 128, 32, 4, 4] for optimial performance
+    # on nvidia sm100+ HW
     pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_M)
     pid_m = pid % num_pid_m
@@ -482,18 +482,21 @@ def block_scale_mxfp_matmul( #


 @triton.jit
-def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scales_ptr, b_scales_ptr, M, N, K, stride_am,
-                                                    stride_ak, stride_bk, stride_bn, stride_ck, stride_cm, stride_cn,
-                                                    stride_asm, stride_ask, stride_bsn, stride_bsk,
-                                                    # Meta-parameters
-                                                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
-                                                    mfma_nonkdim: tl.constexpr, preshuffle: tl.constexpr):
+def _gemm_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scales_ptr, b_scales_ptr, M, N, K, stride_am,
+                                          stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_asm, stride_ask,
+                                          stride_bsn, stride_bsk,
+                                          # Meta-parameters
+                                          DTYPE_A: tl.constexpr, DTYPE_B: tl.constexpr, BLOCK_M: tl.constexpr,
+                                          BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, mfma_nonkdim: tl.constexpr,
+                                          preshuffle: tl.constexpr, fast_math: tl.constexpr):
     """Kernel for computing the matmul C = A x B.
-    A and B inputs are in the microscale fp4 (mxfp4) format.
     A_scales and B_scales are in e8m0 format.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
     """

+    PACK_FACTOR_A: tl.constexpr = 2 if DTYPE_A == "e2m1" else 1
+    PACK_FACTOR_B: tl.constexpr = 2 if DTYPE_B == "e2m1" else 1
+
     pid = tl.program_id(axis=0)

     num_pid_n = tl.cdiv(N, BLOCK_N)
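
The PACK_FACTOR_A / PACK_FACTOR_B constexprs added in the hunk above capture the only packing difference between the supported formats: "e2m1" (fp4) stores two values per uint8 container, everything else stores one element per container, so an fp4 operand with K logical columns is addressed through K // 2 bytes. A small, hedged illustration of that bookkeeping (the nibble order used here is an assumption for the sketch, not taken from the PR):

import torch


def pack_e2m1_codes(codes: torch.Tensor) -> torch.Tensor:
    # codes: 4-bit e2m1 codes stored one per uint8, shape (rows, K) with K even.
    # Two neighbouring K values share one byte; low nibble = even K index (assumed ordering).
    lo, hi = codes[:, 0::2], codes[:, 1::2]
    return (lo | (hi << 4)).to(torch.uint8)


codes = torch.randint(0, 16, (4, 64), dtype=torch.uint8)
packed = pack_e2m1_codes(codes)
print(codes.shape, packed.shape)  # torch.Size([4, 64]) torch.Size([4, 32]), i.e. PACK_FACTOR == 2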
@@ -502,73 +505,99 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scale

     # We assume 32 elements along K share the same scale.
     SCALE_GROUP_SIZE: tl.constexpr = 32
+    MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // SCALE_GROUP_SIZE

     if preshuffle:
         NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 32
     else:
         NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 1

-    num_k_iter = tl.cdiv(K, BLOCK_K // 2)
     # Create pointers for first block of A and B input matrices
     # The BLOCK sizes are of the elements and in fp4 we pack 2 per uint8 container.
-    offs_k = tl.arange(0, BLOCK_K // 2)
-    offs_k_split = offs_k
+    offs_ak = tl.arange(0, BLOCK_K // PACK_FACTOR_A)
+    offs_bk = tl.arange(0, BLOCK_K // PACK_FACTOR_B)
     offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
     offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
-    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak)
-    b_ptrs = b_ptr + (offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

     # Create pointers for the first block of A and B scales
-    offs_asn = (pid_n *
-                (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0, (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE))) % N
-    offs_ks = tl.arange(0, BLOCK_K // SCALE_GROUP_SIZE * NON_K_PRESHUFFLE_BLOCK_SIZE)
+    offs_ks = tl.arange(0, MX_SCALE_BLOCK_K * NON_K_PRESHUFFLE_BLOCK_SIZE)

     # B scales are N x K even though B operand is K x N.
-    b_scale_ptrs = (b_scales_ptr + offs_asn[:, None] * stride_bsn + offs_ks[None, :] * stride_bsk)
-    offs_asm = (pid_m *
-                (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0, (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE))) % M
-    a_scale_ptrs = (a_scales_ptr + offs_asm[:, None] * stride_asm + offs_ks[None, :] * stride_ask)
+    if a_scales_ptr is not None:
+        offs_asm = (pid_m *
+                    (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0,
+                                                                         (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE))) % M
+        a_scale_ptrs = (a_scales_ptr + offs_asm[:, None] * stride_asm + offs_ks[None, :] * stride_ask)
+    if b_scales_ptr is not None:
+        offs_asn = (pid_n *
+                    (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0,
+                                                                         (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE))) % N
+        b_scale_ptrs = (b_scales_ptr + offs_asn[:, None] * stride_bsn + offs_ks[None, :] * stride_bsk)
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

-    for k in range(0, num_k_iter):
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
         if preshuffle:
             # Here we "undo" the shuffle done in global memory (shuffle_scales_cdna4 function).
             if mfma_nonkdim == 32:
-                a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
-                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 2, 32, 4,
-                                                         1).permute(0, 3, 1, 4, 2,
-                                                                    5).reshape(BLOCK_M, BLOCK_K // SCALE_GROUP_SIZE)
-                b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
-                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 2, 32, 4,
-                                                         1).permute(0, 3, 1, 4, 2,
-                                                                    5).reshape(BLOCK_N, BLOCK_K // SCALE_GROUP_SIZE)
+                if a_scales_ptr is not None:
+                    a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                             MX_SCALE_BLOCK_K // 8, 2, 32, 4,
+                                                             1).permute(0, 3, 1, 4, 2,
+                                                                        5).reshape(BLOCK_M, MX_SCALE_BLOCK_K)
+                else:
+                    a_scales = None
+                if b_scales_ptr is not None:
+                    b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                             MX_SCALE_BLOCK_K // 8, 2, 32, 4,
+                                                             1).permute(0, 3, 1, 4, 2,
+                                                                        5).reshape(BLOCK_N, MX_SCALE_BLOCK_K)
+                else:
+                    b_scales = None
             elif mfma_nonkdim == 16:
-                a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
-                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 4, 16, 2, 2,
-                                                         1).permute(0, 5, 3, 1, 4, 2,
-                                                                    6).reshape(BLOCK_M, BLOCK_K // SCALE_GROUP_SIZE)
-                b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
-                                                         BLOCK_K // SCALE_GROUP_SIZE // 8, 4, 16, 2, 2,
-                                                         1).permute(0, 5, 3, 1, 4, 2,
-                                                                    6).reshape(BLOCK_N, BLOCK_K // SCALE_GROUP_SIZE)
+                if a_scales_ptr is not None:
+                    a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                             MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2,
+                                                             1).permute(0, 5, 3, 1, 4, 2,
+                                                                        6).reshape(BLOCK_M, MX_SCALE_BLOCK_K)
+                else:
+                    a_scales = None
+                if b_scales_ptr is not None:
+                    b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
+                                                             MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2,
+                                                             1).permute(0, 5, 3, 1, 4, 2,
+                                                                        6).reshape(BLOCK_N, MX_SCALE_BLOCK_K)
+                else:
+                    b_scales = None
         else:
-            a_scales = tl.load(a_scale_ptrs)
-            b_scales = tl.load(b_scale_ptrs)
+            if a_scales_ptr is not None:
+                a_scales = tl.load(a_scale_ptrs)
+            else:
+                a_scales = None
+            if b_scales_ptr is not None:
+                b_scales = tl.load(b_scale_ptrs)
+            else:
+                b_scales = None

         a = tl.load(a_ptrs)
         b = tl.load(b_ptrs, cache_modifier=None)

-        accumulator += tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1")
+        accumulator += tl.dot_scaled(a, a_scales, DTYPE_A, b, b_scales, DTYPE_B, fast_math=fast_math)

         # Advance the ptrs to the next K block.
-        a_ptrs += (BLOCK_K // 2) * stride_ak
-        b_ptrs += (BLOCK_K // 2) * stride_bk
+        a_ptrs += (BLOCK_K // PACK_FACTOR_A) * stride_ak
+        b_ptrs += (BLOCK_K // PACK_FACTOR_B) * stride_bk
         if preshuffle:
-            a_scale_ptrs += BLOCK_K * stride_ask
-            b_scale_ptrs += BLOCK_K * stride_bsk
+            if a_scales_ptr is not None:
+                a_scale_ptrs += BLOCK_K * stride_ask
+            if b_scales_ptr is not None:
+                b_scale_ptrs += BLOCK_K * stride_bsk
         else:
-            a_scale_ptrs += (BLOCK_K // SCALE_GROUP_SIZE) * stride_ask
-            b_scale_ptrs += (BLOCK_K // SCALE_GROUP_SIZE) * stride_bsk
+            if a_scales_ptr is not None:
+                a_scale_ptrs += MX_SCALE_BLOCK_K * stride_ask
+            if b_scales_ptr is not None:
+                b_scale_ptrs += MX_SCALE_BLOCK_K * stride_bsk

     c = accumulator.to(c_ptr.type.element_ty)

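A sanity check on the mfma_nonkdim == 32 branch above: the in-kernel reshape/permute is the inverse of the host-side preshuffle, so shuffling and then applying the kernel's view must be the identity. The forward shuffle below is a hedged reconstruction derived from that inverse; the repository's actual shuffle_scales_cdna4 helper (defined later in the test) may be written differently even if it produces an equivalent layout.

import torch

BLOCK_M, MX_SCALE_BLOCK_K = 128, 8  # e.g. BLOCK_K = 256 with SCALE_GROUP_SIZE = 32


def preshuffle_nonkdim32(scales):
    # (BLOCK_M, MX_SCALE_BLOCK_K) -> (BLOCK_M // 32, MX_SCALE_BLOCK_K * 32): inverse of the kernel's view.
    return (scales.reshape(BLOCK_M // 32, 32, MX_SCALE_BLOCK_K // 8, 4, 2, 1)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(BLOCK_M // 32, MX_SCALE_BLOCK_K * 32))


def kernel_unshuffle_nonkdim32(shuffled):
    # Mirrors the tl.load(...).reshape(...).permute(0, 3, 1, 4, 2, 5).reshape(...) in the hunk above.
    return (shuffled.reshape(BLOCK_M // 32, MX_SCALE_BLOCK_K // 8, 2, 32, 4, 1)
            .permute(0, 3, 1, 4, 2, 5)
            .reshape(BLOCK_M, MX_SCALE_BLOCK_K))


s = torch.randint(124, 128, (BLOCK_M, MX_SCALE_BLOCK_K), dtype=torch.uint8)
assert torch.equal(kernel_unshuffle_nonkdim32(preshuffle_nonkdim32(s)), s)
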
@@ -583,11 +612,14 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scale

 @pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024)])
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 256), (64, 64, 512), [32, 32, 64]])
+@pytest.mark.parametrize("DTYPE_A, DTYPE_B, FAST_MATH", [("mxfp4", "mxfp4", False), ("fp16", "mxfp8e5", False),
+                                                         ("mxfp8e4", "bf16", False), ("bf16", "mxfp4", True)])
 @pytest.mark.parametrize("mfma_nonkdim", [16, 32])
 @pytest.mark.parametrize("preshuffle", [True, False])
 @pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] == 10, reason="Compilation bug for GB200.")
 @pytest.mark.skipif(is_hip() and not is_hip_cdna4(), reason="Scaled dot is not emulated on other archs yet.")
-def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, device):
+def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, DTYPE_A, DTYPE_B, FAST_MATH, mfma_nonkdim,
+                                     preshuffle, device):
     # This test primarily evaluates correctness for efficient scale packing for MFMA-scaled instructions.
     #
     # Scales are stored as 8-bit tensors, where each element scales 32 values from the A or B operand tensors.
@@ -637,6 +669,12 @@ def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, mfma_no
     if preshuffle and (BLOCK_M < 32 or BLOCK_N < 32 or BLOCK_K < 256):
         pytest.skip("Minimal tile size for preshuffling is 32x32x256")

+    if not (DTYPE_A.startswith("mx") or DTYPE_B.startswith("mx")):
+        pytest.skip("Requires at least 1 microscaling operand")
+
+    if is_cuda() and (DTYPE_A == "mxfp8e4" or DTYPE_B == "mxfp8e4"):
+        pytest.skip("Skip fp8e4 on NV backend")
+
     def shuffle_scales_cdna4(scales: torch.Tensor):
         if not preshuffle:
             return scales
@@ -665,63 +703,77 @@ def run_torch(x, w, x_scales, w_scales, dtype):
         x_f32 = x.to(torch.float32)
         w_f32 = w.to(torch.float32)
         # Next convert the e8m0 scales to f32.
-        x_scales = x_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
-        x_scales_f32 = e8m0_to_f32(x_scales)
-        x_f32 = x_f32 * x_scales_f32
-        w_scales = w_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
-        w_scales_f32 = e8m0_to_f32(w_scales)
-        w_f32 = w_f32 * w_scales_f32
+        if x_scales is not None:
+            x_scales = x_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
+            x_scales_f32 = e8m0_to_f32(x_scales)
+            x_f32 = x_f32 * x_scales_f32
+        if w_scales is not None:
+            w_scales = w_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
+            w_scales_f32 = e8m0_to_f32(w_scales)
+            w_f32 = w_f32 * w_scales_f32
         return torch.mm(x_f32, w_f32.T).to(dtype)

-    def generate_gemm_afp4wfp4_inputs(M, N, K):
+    dtype_to_torch_type = {
+        "fp16": torch.half, "bf16": torch.bfloat16, "mxfp8e5": torch.float8_e5m2, "mxfp8e4": torch.float8_e4m3fn
+    }
+
+    dtype_to_triton_type = {"fp16": "fp16", "bf16": "bf16", "mxfp8e5": "e5m2", "mxfp8e4": "e4m3", "mxfp4": "e2m1"}
+
+    def generate_gemm_input(dim0, dim1, dtype):
         torch.manual_seed(5)
         SCALE_GROUP_SIZE = 32

-        x = MXFP4Tensor(size=(M, K), device="cuda").random()
-        w = MXFP4Tensor(size=(N, K), device="cuda").random()
-
-        x_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, M), dtype=torch.uint8, device="cuda")
-        w_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, N), dtype=torch.uint8, device="cuda")
-        x_scales = x_scales.T
-        w_scales = w_scales.T
-        x_scales_shuffled = shuffle_scales_cdna4(x_scales)
-        w_scales_shuffled = shuffle_scales_cdna4(w_scales)
-
-        return (
-            x,
-            w,
-            x_scales,
-            w_scales,
-            x_scales_shuffled,
-            w_scales_shuffled,
-        )
-
-    x_mxfp4, w_mxfp4, x_scales, w_scales, x_scales_triton, w_scales_triton = generate_gemm_afp4wfp4_inputs(M, N, K)
-
-    x = x_mxfp4.to_packed_tensor(dim=1)
-    w = w_mxfp4.to_packed_tensor(dim=1)
-
-    torch_out = run_torch(x_mxfp4, w_mxfp4, x_scales, w_scales, torch.float32)
-    M, K = x.shape
-    N, K = w.shape
+        if dtype == "mxfp4":
+            v = MXFP4Tensor(size=(dim0, dim1), device="cuda").random()
+        elif dtype == "mxfp8e5":
+            v = torch.randint(20, 40, (dim0, dim1), dtype=torch.uint8).view(torch.float8_e5m2).to(device)
+        elif dtype == "mxfp8e4":
+            v = torch.randint(20, 40, (dim0, dim1), dtype=torch.uint8).view(torch.float8_e4m3fn).to(device)
+        elif dtype in ("fp16", "bf16"):
+            v = torch.randn((dim0, dim1), device=device, dtype=dtype_to_torch_type[dtype])
+        else:
+            raise ValueError(f"Unsupported data type: {dtype}")
+
+        if dtype.startswith("mx"):
+            scales = torch.randint(124, 128, (dim0, dim1 // SCALE_GROUP_SIZE), dtype=torch.uint8, device=device)
+            scales_shuffled = shuffle_scales_cdna4(scales)
+        else:
+            scales = None
+            scales_shuffled = None
+
+        return (v, scales, scales_shuffled)
+
+    x, x_scales, x_scales_triton = generate_gemm_input(M, K, DTYPE_A)
+    w, w_scales, w_scales_triton = generate_gemm_input(N, K, DTYPE_B)
+
+    torch_out = run_torch(x, w, x_scales, w_scales, torch.float32)
+
+    if DTYPE_A == "mxfp4":
+        x = x.to_packed_tensor(dim=1)
+
+    if DTYPE_B == "mxfp4":
+        w = w.to_packed_tensor(dim=1)
+
     w = w.T
     triton_out = torch.empty((M, N), device=x.device)

+    x_scales_strides = x_scales_triton.stride() if x_scales is not None else (None, None)
+    w_scales_strides = w_scales_triton.stride() if w_scales is not None else (None, None)
+
     kernel_kwargs = {}
     if is_hip():
         kernel_kwargs["matrix_instr_nonkdim"] = mfma_nonkdim

     grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
-    k = _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4[grid](x, w, triton_out, x_scales_triton,
-                                                              w_scales_triton, M, N, K, x.stride(0), x.stride(1),
-                                                              w.stride(0), w.stride(1), 0, triton_out.stride(0),
-                                                              triton_out.stride(1), x_scales_triton.stride(0),
-                                                              x_scales_triton.stride(1), w_scales_triton.stride(0),
-                                                              w_scales_triton.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
-                                                              mfma_nonkdim, preshuffle, num_warps=8, num_stages=1,
-                                                              **kernel_kwargs)
+    k = _gemm_kernel_preshuffled_scales_cdna4[grid](x, w, triton_out, x_scales_triton, w_scales_triton, M, N, K,
+                                                    x.stride(0), x.stride(1), w.stride(0), w.stride(1),
+                                                    triton_out.stride(0), triton_out.stride(1), *x_scales_strides,
+                                                    *w_scales_strides, dtype_to_triton_type[DTYPE_A],
+                                                    dtype_to_triton_type[DTYPE_B], BLOCK_M, BLOCK_N, BLOCK_K,
+                                                    mfma_nonkdim, preshuffle, fast_math=FAST_MATH, num_warps=8,
+                                                    num_stages=1, **kernel_kwargs)
     triton_out = triton_out.to(torch.float32)
-    torch.testing.assert_close(torch_out, triton_out)
+    torch.testing.assert_close(torch_out, triton_out, atol=2e-5, rtol=1e-4)
     if is_hip() and preshuffle:
         assert "tilesPerWarp = [2, 2]" in k.asm["ttgir"]
         assert "ds_read_u8" not in k.asm["amdgcn"]
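
run_torch above converts the uint8 scales with an e8m0_to_f32 helper that is defined elsewhere in test_matmul.py and not shown in this diff. For reference, a hedged sketch of an e8m0 decode (exponent-only OCP MX scale with bias 127; the repository's helper may treat the NaN encoding and edge cases differently):

import torch


def e8m0_to_f32_sketch(e: torch.Tensor) -> torch.Tensor:
    # e8m0 stores only a biased exponent: value = 2 ** (int(e) - 127); 0xFF is reserved for NaN.
    out = torch.pow(torch.tensor(2.0), e.to(torch.int32) - 127)
    return torch.where(e == 0xFF, torch.full_like(out, float("nan")), out)


# The tests draw scales from randint(124, 128), i.e. 2**-3 .. 2**0, which keeps the reference matmul well scaled.
print(e8m0_to_f32_sketch(torch.tensor([124, 125, 126, 127], dtype=torch.uint8)))
# tensor([0.1250, 0.2500, 0.5000, 1.0000])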
@@ -738,7 +790,7 @@ def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_
         NUM_STAGES = min(NUM_STAGES, 2)
     elif BLOCK_K == 256:
         NUM_STAGES = min(NUM_STAGES, 3)
-    #since the block size are big we use num_warps = 8 to avoid pressure problems.
+    # since the block size are big we use num_warps = 8 to avoid pressure problems.
     num_warps = 8
     torch.manual_seed(42)
     dtype_src_str = "float8e5"
