Commit 7dbc816

[MoE training] torch.compile support for ScaledGroupedMMTensor (#2509)
Parent: 0935f66

File tree: 4 files changed, +24 −13 lines

test/prototype/moe_training/test_scaled_grouped_mm.py

8 additions, 9 deletions

```diff
@@ -8,24 +8,18 @@
 import torch
 from torch.nn import functional as F
 
-pytest.importorskip("triton", reason="Triton required to run this test")
-
-from torchao.prototype.moe_training.utils import (
-    _to_mxfp8_per_group_colwise,
-    _to_mxfp8_per_group_rowwise,
-    generate_jagged_offs,
-)
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
 
 # We need to skip before doing any imports which would use triton, since
 # triton won't be available on CPU builds and torch < 2.5
 if not (
-    TORCH_VERSION_AT_LEAST_2_5
+    TORCH_VERSION_AT_LEAST_2_7
     and torch.cuda.is_available()
     and torch.cuda.get_device_capability()[0] >= 9
 ):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
+pytest.importorskip("triton", reason="Triton required to run this test")
 
 from torchao.float8.config import (
     Float8LinearConfig,
@@ -39,6 +33,11 @@
     _emulated_mxfp8_scaled_grouped_mm_2d_3d,
     _scaled_grouped_mm,
 )
+from torchao.prototype.moe_training.utils import (
+    _to_mxfp8_per_group_colwise,
+    _to_mxfp8_per_group_rowwise,
+    generate_jagged_offs,
+)
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.testing.utils import skip_if_rocm
 
```
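The reordering here matters because `pytest.importorskip("triton")` itself attempts the triton import, which fails outright on CPU-only builds; per the file's own comment, the version/CUDA gate must run before any import that would pull in triton. A minimal, self-contained sketch of that gating pattern, with a crude stand-in for torchao's `TORCH_VERSION_AT_LEAST_2_7` flag:

```python
# Sketch of the module-level gating pattern, assuming a pytest test file.
import pytest
import torch

# Crude stand-in for torchao.utils.TORCH_VERSION_AT_LEAST_2_7 (plain string
# comparison is not robust across versions like "2.10"; illustration only).
TORCH_VERSION_OK = torch.__version__ >= "2.7"

if not (
    TORCH_VERSION_OK
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 9  # SM90 (Hopper) or newer
):
    # Skips the entire module at collection time, before any
    # triton-dependent import below has a chance to run.
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

# Only now is it safe to probe for triton itself.
pytest.importorskip("triton", reason="Triton required to run this test")
```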

test/prototype/moe_training/test_training.py

7 additions, 1 deletion

```diff
@@ -34,7 +34,8 @@
         ["does.not.exist"],
     ],
 )
-def test_moe_float8_training(target_fqns: list[str]):
+@pytest.mark.parametrize("compile", [False, True])
+def test_moe_float8_training(target_fqns: list[str], compile: bool):
     model_args = TransformerModelArgs(
         moe_enabled=True,
         num_experts=8,
@@ -72,6 +73,11 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         target_fqns=target_fqns,
     )
 
+    if compile:
+        # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
+        model = torch.compile(model, fullgraph=False)
+        ref_model = torch.compile(ref_model, fullgraph=False)
+
     # inputs
     batch, seq, dim = 8, 2048, 256
     ref_x = torch.randn(
```
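The pattern this diff adds, running one test body in both eager and compiled mode via a `compile` parameter, is easy to reuse elsewhere. A minimal sketch under the assumption of a plain `nn.Linear` rather than the torchao MoE model:

```python
# Hypothetical standalone example of the eager-vs-compiled parametrization;
# not the torchao test itself.
import pytest
import torch
import torch.nn as nn


@pytest.mark.parametrize("compile", [False, True])
def test_linear_matches_eager(compile: bool):
    torch.manual_seed(0)
    model = nn.Linear(16, 16)
    ref_model = nn.Linear(16, 16)
    ref_model.load_state_dict(model.state_dict())

    if compile:
        # fullgraph=False tolerates graph breaks, mirroring the diff's TODO.
        model = torch.compile(model, fullgraph=False)

    x = torch.randn(4, 16)
    # Compiled and eager paths should agree numerically.
    torch.testing.assert_close(model(x), ref_model(x))
```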

torchao/prototype/moe_training/kernels/jagged_float8_scales.py

6 additions, 2 deletions

```diff
@@ -42,7 +42,10 @@
     for block_size_cols in block_sizes
 ]
 
+from torch.library import triton_op, wrap_triton
 
+
+@triton_op("torchao::triton_fp8_row_major_jagged_rowwise_scales", mutates_args={})
 def triton_fp8_row_major_jagged_rowwise_scales(
     hp_tensor: torch.Tensor,
     offsets: torch.Tensor,
@@ -90,7 +93,7 @@ def triton_fp8_row_major_jagged_rowwise_scales(
         triton.cdiv(m, meta["BLOCK_SIZE_ROWS"]),
         offsets.numel(),
     )
-    _triton_fp8_row_major_jagged_rowwise_scales[grid](
+    wrap_triton(_triton_fp8_row_major_jagged_rowwise_scales)[grid](
         hp_tensor,
         offsets,
         output_buffer,
@@ -204,6 +207,7 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
     tl.store(out_ptr + out_offs, fp8_data, mask=block_mask)
 
 
+@triton_op("torchao::triton_fp8_col_major_jagged_colwise_scales", mutates_args={})
def triton_fp8_col_major_jagged_colwise_scales(
     hp_tensor: torch.Tensor,
     offsets: torch.Tensor,
@@ -251,7 +255,7 @@ def triton_fp8_col_major_jagged_colwise_scales(
         triton.cdiv(n, meta["BLOCK_SIZE_COLS"]),
         offsets.numel(),
     )
-    _triton_fp8_col_major_jagged_colwise_scales[grid](
+    wrap_triton(_triton_fp8_col_major_jagged_colwise_scales)[grid](
         hp_tensor,
         offsets,
         output_buffer,
```
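The `triton_op`/`wrap_triton` pair from `torch.library` is what makes these raw Triton kernels compatible with `torch.compile`: `triton_op` registers the Python wrapper as a custom op, and `wrap_triton` wraps the kernel launch so the tracer can capture it instead of graph-breaking. A minimal sketch with a hypothetical `mylib::add` op; the kernel and names are illustrative, not torchao code:

```python
import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Standard elementwise-add Triton kernel.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


@triton_op("mylib::add", mutates_args={})
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)
    # wrap_triton makes the raw kernel launch visible to torch.compile
    # tracing, mirroring the change in the diff above.
    wrap_triton(_add_kernel)[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out
```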

torchao/prototype/moe_training/tensor.py

3 additions, 1 deletion

```diff
@@ -123,7 +123,9 @@ def __repr__(self):
         return f"ScaledGroupedMMTensor(data={self._data})"
 
     def __tensor_flatten__(self):
-        return ["_data"]
+        # Metadata is empty but needed to make the subclass traceable for torch.compile.
+        metadata = {}
+        return ["_data"], metadata
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
```
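Returning a `(tensor_attr_names, metadata)` pair from `__tensor_flatten__`, even with empty metadata, is part of the contract `torch.compile` uses to decompose and rebuild wrapper subclasses. A minimal sketch of the round-trip with a hypothetical `MyWrapperTensor`; ScaledGroupedMMTensor's real implementation carries additional logic:

```python
import torch


class MyWrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        # Wrapper subclass whose outer shape/dtype/device mirror the inner tensor.
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    def __tensor_flatten__(self):
        # (names of inner-tensor attributes, extra metadata); torch.compile
        # requires both elements even when the metadata dict is empty.
        return ["_data"], {}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        # Rebuild the subclass from the flattened inner tensors.
        return MyWrapperTensor(inner_tensors["_data"])
```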
