
Commit 28533b1

[TRITON_KERNELS] Apply MXFP4 Hopper layout on A100 (#8474)
MXFP4 matmul performance is better with this layout on A100, so change the default layout. Also changed the layout names to reflect that they are used for both Hopper and Ampere.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because the change should be covered with existing tests.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 3200f5a commit 28533b1
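The functional change is the capability cutoff in `make_default_matmul_mxfp4_w_layout` and `make_default_matmul_mxfp4_w_scale_layout` (see the `layout.py` diff below): the check drops from `cuda_capability_geq(9)` to `cuda_capability_geq(8)`, so A100 (Ampere, sm_80) now gets the same swizzled MXFP4 layouts as Hopper. A minimal sketch, not part of this commit, of what the value-layout default returns afterwards; it assumes `triton_kernels` is installed and a CUDA GPU is present, and `mx_axis=1` is just an illustrative value:

```python
from triton_kernels.tensor_details.layout import make_default_matmul_mxfp4_w_layout

# Returns a (layout class, constructor kwargs) pair based on the GPU's compute capability.
layout_cls, layout_kwargs = make_default_matmul_mxfp4_w_layout(mx_axis=1)
print(layout_cls.__name__, layout_kwargs)
# Ampere (sm_80) and Hopper (sm_90): HopperAmpereMXValueLayout {'mx_axis': 1}
# Blackwell (sm_100+):               BlackwellMXValueLayout {}
# Anything older:                    StridedLayout {}
```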

5 files changed, +18 -19 lines


python/triton_kernels/tests/test_tensor_details/test_layout_hopper.py

Lines changed: 5 additions & 5 deletions
@@ -1,7 +1,7 @@
 import pytest
 from triton._internal_testing import is_cuda
 from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4
-from triton_kernels.tensor_details.layout import HopperMXScaleLayout, HopperMXValueLayout
+from triton_kernels.tensor_details.layout import HopperAmpereMXScaleLayout, HopperAmpereMXValueLayout
 from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
 from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
 from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
@@ -25,7 +25,7 @@ def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version):
         x = x.mT
     if x.shape[1 - mx_axis] < 32:
         pytest.skip("Not enough elements along non-mx axis")
-    layout = HopperMXValueLayout(x.shape, mx_axis, mma_version)
+    layout = HopperAmpereMXValueLayout(x.shape, mx_axis, mma_version)
     res = layout.unswizzle_data(layout.swizzle_data(x))
     assert (res == x).all()

@@ -35,7 +35,7 @@ def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version):
 @pytest.mark.parametrize("shape", [(256, 64), (256, 128), (256, 256)])
 def test_mxfp4_scale_roundtrip(shape, mx_axis, num_warps):
     x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
-    layout = HopperMXScaleLayout(x.shape, mx_axis=mx_axis, num_warps=num_warps)
+    layout = HopperAmpereMXScaleLayout(x.shape, mx_axis=mx_axis, num_warps=num_warps)
     res = layout.unswizzle_data(layout.swizzle_data(x))
     assert (res[:shape[0], :shape[1]] == x).all()

@@ -84,8 +84,8 @@ def test_upcast_mxfp4_to_bf16():
     x_bf16 = upcast_from_mxfp(x_fp4_val, x_fp4_scale, x.dtype, axis=mx_axis)
     x_fp4_val = wrap_torch_tensor(x_fp4_val, dtype=FP4)
     x_fp4_scale = wrap_torch_tensor(x_fp4_scale)
-    x_fp4_val = convert_layout(x_fp4_val, HopperMXValueLayout, mx_axis=mx_axis)
-    x_fp4_scale = convert_layout(x_fp4_scale, HopperMXScaleLayout, mx_axis=mx_axis, num_warps=num_warps)
+    x_fp4_val = convert_layout(x_fp4_val, HopperAmpereMXValueLayout, mx_axis=mx_axis)
+    x_fp4_scale = convert_layout(x_fp4_scale, HopperAmpereMXScaleLayout, mx_axis=mx_axis, num_warps=num_warps)
     y = torch.empty_like(x_bf16)
     _upcast_mxfp4_to_bf16[(1, )](
         y, x_fp4_val.storage.data, x_fp4_scale.storage.data, #
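The roundtrip contract exercised by these tests is unchanged by the rename. A minimal sketch of the same check outside pytest, assuming a CUDA device; the shape, `mx_axis`, and `mma_version` values are illustrative:

```python
import torch
from triton_kernels.tensor_details.layout import HopperAmpereMXValueLayout

x = torch.randint(0, 256, (256, 64), dtype=torch.uint8, device="cuda")
# swizzle_data and unswizzle_data must be exact inverses for the layout to be usable.
layout = HopperAmpereMXValueLayout(x.shape, mx_axis=1, mma_version=3)
assert (layout.unswizzle_data(layout.swizzle_data(x)) == x).all()
```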

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
 import triton
 from triton_kernels import target_info
 from triton_kernels.tensor import get_layout, bitwidth, FP4
-from triton_kernels.tensor_details.layout import HopperMXScaleLayout
+from triton_kernels.tensor_details.layout import HopperAmpereMXScaleLayout
 from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE


@@ -18,7 +18,7 @@ def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n):
 def compute_block_n(n: int, arch, precision_config):
     # block_n:
     layout = get_layout(precision_config.weight_scale)
-    if isinstance(layout, HopperMXScaleLayout) and layout.num_warps == 4:
+    if isinstance(layout, HopperAmpereMXScaleLayout) and layout.num_warps == 4:
         return 128, 128
     elif precision_config.max_num_imprecise_acc is None and n > 128:
         return 256, 256
@@ -60,7 +60,7 @@ def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:

 def compute_num_warps(block_m, block_n, is_persistent: bool, precision_config):
     layout = get_layout(precision_config.weight_scale)
-    if isinstance(layout, HopperMXScaleLayout):
+    if isinstance(layout, HopperAmpereMXScaleLayout):
         return layout.num_warps
     return max(block_m * block_n // 4096, 4 if is_persistent else 1)
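These heuristics read the warp count straight off the scale layout attached to the weight scales. A minimal illustration of the value they pick up, assuming a CUDA device; the shape, `mx_axis`, and `num_warps` are chosen arbitrarily:

```python
import torch
from triton_kernels.tensor_details.layout import HopperAmpereMXScaleLayout

scales = torch.randint(0, 256, (256, 64), dtype=torch.uint8, device="cuda")
layout = HopperAmpereMXScaleLayout(scales.shape, mx_axis=1, num_warps=4)
# compute_num_warps() returns this value verbatim when the weight scales carry
# this layout, and compute_block_n() selects 128 when it equals 4.
print(layout.num_warps)  # -> 4
```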

python/triton_kernels/triton_kernels/tensor_details/layout.py

Lines changed: 8 additions & 9 deletions
@@ -1,8 +1,8 @@
 from .layout_details.base import Layout
 from .layout_details.blackwell_scale import BlackwellMXScaleLayout
 from .layout_details.blackwell_value import BlackwellMXValueLayout
-from .layout_details.hopper_scale import HopperMXScaleLayout
-from .layout_details.hopper_value import HopperMXValueLayout
+from .layout_details.hopper_scale import HopperAmpereMXScaleLayout
+from .layout_details.hopper_value import HopperAmpereMXValueLayout
 from .layout_details.cdna4_scale import CDNA4MXScaleLayout
 from .layout_details.strided import StridedLayout
 from ..target_info import cuda_capability_geq, is_hip_cdna4
@@ -11,19 +11,18 @@
     "Layout",
     "BlackwellMXValueLayout",
     "BlackwellMXScaleLayout",
-    "HopperMXScaleLayout",
-    "HopperMXValueLayout",
+    "HopperAmpereMXScaleLayout",
+    "HopperAmpereMXValueLayout",
     "CDNA4MXScaleLayout",
     "StridedLayout",
 ]


 def make_default_matmul_mxfp4_w_layout(mx_axis: int):
     if cuda_capability_geq(10):
-        # return StridedLayout, dict()
         return BlackwellMXValueLayout, dict()
-    elif cuda_capability_geq(9):
-        return HopperMXValueLayout, {"mx_axis": mx_axis}
+    elif cuda_capability_geq(8):
+        return HopperAmpereMXValueLayout, {"mx_axis": mx_axis}
     else:
         return StridedLayout, dict()

@@ -34,7 +33,7 @@ def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
     else:
         if cuda_capability_geq(10):
             return BlackwellMXScaleLayout, dict()
-        elif cuda_capability_geq(9):
-            return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}
+        elif cuda_capability_geq(8):
+            return HopperAmpereMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}

     return StridedLayout, dict()
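The scale-layout default follows the same capability cutoff as the value layout. A minimal sketch of the new dispatch, assuming `triton_kernels` is installed and a CUDA GPU is present; branches not shown in this hunk handle other targets, and `mx_axis`/`num_warps` are illustrative:

```python
from triton_kernels.tensor_details.layout import make_default_matmul_mxfp4_w_scale_layout

# Returns a (layout class, constructor kwargs) pair for the weight scales.
scale_cls, scale_kwargs = make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=8)
print(scale_cls.__name__, scale_kwargs)
# Ampere (sm_80) and Hopper (sm_90): HopperAmpereMXScaleLayout {'mx_axis': 1, 'num_warps': 8}
# Blackwell (sm_100+):               BlackwellMXScaleLayout {}
# Fallback:                          StridedLayout {}
```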

python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from .base import Layout


-class HopperMXScaleLayout(Layout):
+class HopperAmpereMXScaleLayout(Layout):
     name: str = "HOPPER_SCALE"

     def __init__(self, shape, mx_axis, num_warps=8) -> None:

python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ def _unpack_bits(x, mx_axis: int):
 # -----------------------------------------------------------------------


-class HopperMXValueLayout(Layout):
+class HopperAmpereMXValueLayout(Layout):
     name: str = "HOPPER_VALUE"

     def __init__(self, shape, mx_axis, mma_version=3):
