
Commit bfc04bc

[Bench][AMD] Support Scale Preshuffling on GFX950 (#7836)
This PR adds weight scale preshuffling to the benchmark for gfx950 (CDNA4) hardware: on CDNA4 the default mxfp4 weight-scale layout now resolves to a new CDNA4MXScaleLayout, and the matmul_ogs kernel unswizzles it in-kernel.
1 parent 48862cd

File tree

6 files changed: +84, -9 lines

python/triton_kernels/bench/bench_mlp.py
python/triton_kernels/tests/test_matmul.py
python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py
python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py
python/triton_kernels/triton_kernels/tensor_details/layout.py
python/triton_kernels/triton_kernels/tensor_details/layout_details/cdna4_scale.py

python/triton_kernels/bench/bench_mlp.py

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
     # -- numerics --
     opt1 = dict()
     opt2 = dict()
-    if w_dtype == "mx4" and not is_hip():
+    if w_dtype == "mx4":
        num_warps = 4 if batch <= 512 else 8
        value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
        scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
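
With the "not is_hip()" guard gone, the mx4 branch now runs on AMD as well, and on CDNA4 the scale layout it requests resolves to the new preshuffled layout (see layout.py below). A minimal sketch of that selection, assuming layout here is the triton_kernels.tensor_details.layout module changed in this commit:

from triton_kernels.tensor_details import layout

# On CDNA4 this now returns (CDNA4MXScaleLayout, {}); on Blackwell/Hopper it
# keeps the existing CUDA layouts, and elsewhere it falls back to StridedLayout.
scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1)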

python/triton_kernels/tests/test_matmul.py

Lines changed: 15 additions & 1 deletion
@@ -297,7 +297,12 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         pytest.skip("fused scatter scratchpad not supported with split_k")
     if hbm_swizzling:
         if is_hip():
-            pytest.skip("NYI. HBM swizzling just implemented for CUDA.")
+            if not is_hip_cdna4():
+                pytest.skip("Scale preshuffling on AMD GPUs is only emulated for CDNA4 so far.")
+            if "mx" not in weight_dtype_str:
+                pytest.skip("Non-scale swizzling not supported on CDNA4 yet")
+            if n % 32 != 0 or k % (32 * 8) != 0:
+                pytest.skip(f"Shape {m}x{n}x{k} is not supported for scale swizzling on AMD GPUs")
         if torch.cuda.get_device_capability()[0] < 9:
             pytest.skip("NYI. Ampere swizzling.")
         if torch.cuda.get_device_capability()[0] < 10:
@@ -327,6 +332,15 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         "is_persistent": is_persistent,
         "epilogue_subtile": epilogue_subtile,
     }
+
+    if is_hip() and hbm_swizzling and "float4" in weight_dtype_str:
+        # Minimum block size to satisfy scale preshuffling
+        constraints.update({
+            "block_m": 32,
+            "block_n": 32,
+            "block_k": 256
+        })
+
     opt_flags.update_opt_flags_constraints(constraints)

     weight_mxfp = weight_dtype_str.startswith("mx")
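
The new skips encode the shape requirements for CDNA4 scale preshuffling: each MX scale covers 32 elements along K, and the preshuffled layout packs scales in groups of 8 along K (hence k % 256 == 0) and in groups of 32 along N. A hypothetical helper restating the same check:

def cdna4_scale_swizzle_supported(n: int, k: int) -> bool:
    # Hypothetical helper, not part of the commit: mirrors the skip above.
    # N must tile into 32-wide groups; K must cover 8 scale groups of 32 elements.
    return n % 32 == 0 and k % (32 * 8) == 0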

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 10 additions & 0 deletions
@@ -5,6 +5,7 @@
 from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
 from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
 from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
+from triton_kernels.tensor_details.layout_details.cdna4_scale import unswizzle_mx_scale_cdna4
 from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
 from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
@@ -209,6 +210,13 @@ def _matmul_ogs(
         PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * 32
         SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 32
         stride_scale_k = stride_w_mx_k
+    elif SWIZZLE_MX_SCALE == "CDNA4_SCALE":
+        tl.static_assert(stride_w_mx_k is not None)
+        tl.static_assert(stride_w_mx_n is not None)
+        NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 32
+        PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * NON_K_PRESHUFFLE_BLOCK_SIZE
+        SCALE_BLOCK_N: tl.constexpr = BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE
+        stride_scale_k = stride_w_mx_k
     else:
         PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K
         SCALE_BLOCK_N: tl.constexpr = BLOCK_N
@@ -281,6 +289,8 @@ def _matmul_ogs(
             # Handshake with the swizzling code
             num_warps: tl.constexpr = tl.extra.cuda.num_warps()
             w_scales = unswizzle_mxfp4_scale_hopper(tl.load(WMxScalePtrs), mx_axis=1, num_warps=num_warps)
+        elif SWIZZLE_MX_SCALE == "CDNA4_SCALE":
+            w_scales = unswizzle_mx_scale_cdna4(tl.load(WMxScalePtrs), BLOCK_N, MX_SCALE_BLOCK_K)
         else:
             w_scales = tl.load(WMxScalePtrs, mask=mask_k_scale[None, :])

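
For concreteness, with the minimum tile pinned in test_matmul.py above (block_n = 32, block_k = 256) and assuming MX_SCALE_BLOCK_K = BLOCK_K // MXFP_BLOCK_SIZE with the usual MXFP block size of 32, the CDNA4 branch works out to:

BLOCK_N, BLOCK_K = 32, 256
MX_SCALE_BLOCK_K = BLOCK_K // 32            # 8 scale groups along K
PACKED_MX_BLOCK = MX_SCALE_BLOCK_K * 32     # 256 preshuffled scale bytes per K-tile
SCALE_BLOCK_N = BLOCK_N // 32               # 1 group of 32 preshuffled rows along N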

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 3 additions & 1 deletion
@@ -46,7 +46,7 @@ def make_default_opt_flags_amd(
     epilogue_effective_itemsize,
     constraints,
 ):
-    constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
+    constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
     # tokens per expert
     if routing_data is None:
@@ -86,6 +86,8 @@
     # TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
     if constraints.get("block_k", None) is not None:
         block_k = constraints["block_k"]
+    if constraints.get("block_n", None) is not None:
+        block_n = constraints["block_n"]
     is_persistent = constraints.get("is_persistent", False)
     # split_k:
     if constraints.get("split_k", None) is not None:
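
Since "block_n" is now in constraints_supported, AMD callers can pin the N tile the same way test_matmul.py does above. A minimal sketch, assuming opt_flags is imported from the module this file lives in:

from triton_kernels.matmul_ogs_details import opt_flags

# Before this change, "block_n" tripped the constraints_supported assert in
# make_default_opt_flags_amd on AMD; now it pins the N tile directly.
opt_flags.update_opt_flags_constraints({"block_n": 32, "block_k": 256})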

python/triton_kernels/triton_kernels/tensor_details/layout.py

Lines changed: 11 additions & 6 deletions
@@ -2,14 +2,16 @@
 from .layout_details.blackwell_scale import BlackwellMXScaleLayout
 from .layout_details.hopper_scale import HopperMXScaleLayout
 from .layout_details.hopper_value import HopperMXValueLayout
+from .layout_details.cdna4_scale import CDNA4MXScaleLayout
 from .layout_details.strided import StridedLayout
-from ..target_info import cuda_capability_geq
+from ..target_info import cuda_capability_geq, is_hip_cdna4

 __all__ = [
     "Layout",
     "BlackwellMXScaleLayout",
     "HopperMXScaleLayout",
     "HopperMXValueLayout",
+    "CDNA4MXScaleLayout",
     "StridedLayout",
 ]

@@ -24,9 +26,12 @@ def make_default_matmul_mxfp4_w_layout(mx_axis: int):


 def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
-    if cuda_capability_geq(10):
-        return BlackwellMXScaleLayout, dict()
-    elif cuda_capability_geq(9):
-        return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}
+    if is_hip_cdna4():
+        return CDNA4MXScaleLayout, dict()
     else:
-        return StridedLayout, dict()
+        if cuda_capability_geq(10):
+            return BlackwellMXScaleLayout, dict()
+        elif cuda_capability_geq(9):
+            return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}
+
+    return StridedLayout, dict()
python/triton_kernels/triton_kernels/tensor_details/layout_details/cdna4_scale.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import triton
+import triton.language as tl
+from .base import Layout
+
+NON_K_PRESHUFFLE_BLOCK_SIZE = 32
+
+
+class CDNA4MXScaleLayout(Layout):
+    name: str = "CDNA4_SCALE"
+
+    def __init__(self, shape) -> None:
+        super().__init__(shape)
+
+    def swizzle_data(self, data):
+        block_shape = data.shape
+        SCALE_K = block_shape[-2]
+        N = block_shape[-1]
+        data = data.transpose(-1, -2)
+        data = data.view(-1, N // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, SCALE_K // 8, 2, 4, 1)
+        data = data.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous()
+        if len(block_shape) == 3:
+            E = block_shape[0]
+            data = data.reshape(E, N // 32, SCALE_K * 32)
+        else:
+            assert len(block_shape) == 2
+            data = data.reshape(N // 32, SCALE_K * 32)
+        return data.transpose(-1, -2)
+
+    def unswizzle_data(self, data):
+        raise NotImplementedError()
+
+    def swizzle_block_shape(self, block_shape):
+        SCALE_K = block_shape[-2]
+        N = block_shape[-1]
+        return block_shape[:-2] + [N // 32, SCALE_K * 32]
+
+
+@triton.jit
+def unswizzle_mx_scale_cdna4(x, BLOCK_N: tl.constexpr, MX_SCALE_BLOCK_K: tl.constexpr,
+                             N_PRESHUFFLE_FACTOR: tl.constexpr = NON_K_PRESHUFFLE_BLOCK_SIZE):
+    x = x.reshape(BLOCK_N // N_PRESHUFFLE_FACTOR, MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2, 1)
+    x = x.permute(0, 5, 3, 1, 4, 2, 6)
+    x = x.reshape(BLOCK_N, MX_SCALE_BLOCK_K)
+    return x
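
Reading the view/permute pair: swizzle_data regroups a logically (SCALE_K, N) scale tensor into 32-wide N groups and 8-scale K groups, so the swizzled tensor comes out as (SCALE_K * 32, N // 32). A hypothetical sanity check, assuming (as the transposes in swizzle_data suggest) that scales arrive N-major in memory:

import torch
from triton_kernels.tensor_details.layout_details.cdna4_scale import CDNA4MXScaleLayout

# Hypothetical check, not part of the commit: K = 2048 gives SCALE_K = 64
# scales per column of a K x N = 2048 x 128 weight.
scales = torch.randint(0, 256, (128, 64), dtype=torch.uint8).transpose(-1, -2)
swizzled = CDNA4MXScaleLayout(scales.shape).swizzle_data(scales)
assert swizzled.shape == (64 * 32, 128 // 32)   # (SCALE_K * 32, N // 32)

Inside the kernel, unswizzle_mx_scale_cdna4 is the per-tile inverse: it recovers (BLOCK_N, MX_SCALE_BLOCK_K) scales from each loaded tile, as shown in _matmul_ogs.py above.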
