Commit be40518

[moe training] set token group alignment size to 16 for fp8 training test (#2678)
1 parent 7a13eb0 commit be40518

File tree

1 file changed: +8 −0 lines changed

test/prototype/moe_training/test_training.py

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,9 @@

 # this test requires torchtitan
 try:
+    from torchtitan.experiments.llama4.infra.expert_parallel import (
+        set_token_group_alignment_size_m,
+    )
     from torchtitan.experiments.llama4.model.args import TransformerModelArgs
     from torchtitan.experiments.llama4.model.moe import MoE
 except ImportError:
@@ -36,6 +39,11 @@
 )
 @pytest.mark.parametrize("compile", [False, True])
 def test_moe_float8_training(target_fqns: list[str], compile: bool):
+    # Set token group alignment size to 16. This is required so that
+    # each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input`
+    # has the contraction dim be divisible by 16. 16 byte alignment is required
+    # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
+    set_token_group_alignment_size_m(16)
     model_args = TransformerModelArgs(
         moe_enabled=True,
         num_experts=8,
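
For context on the arithmetic in the new comment: fp8 elements are 1 byte each, so a 16-byte alignment requirement on the stride-1 dim means each per-expert token group must have a size divisible by 16 / 1 = 16 elements. The sketch below only illustrates the round-up-to-a-multiple idea; `pad_token_group_sizes` and `ALIGNMENT_SIZE_M` are hypothetical names for illustration, not torchtitan's implementation of `set_token_group_alignment_size_m`.

# Minimal sketch (assumed helper, not torchtitan's code): round each expert's
# token count up to a multiple of the alignment size, so every per-expert gemm
# in the grouped gemm has a contraction dim divisible by 16.
import torch

ALIGNMENT_SIZE_M = 16  # mirrors set_token_group_alignment_size_m(16) in the test above


def pad_token_group_sizes(tokens_per_expert: torch.Tensor) -> torch.Tensor:
    """Round each expert's token count up to the next multiple of ALIGNMENT_SIZE_M."""
    return (
        (tokens_per_expert + ALIGNMENT_SIZE_M - 1) // ALIGNMENT_SIZE_M
    ) * ALIGNMENT_SIZE_M


# Example: raw per-expert token counts produced by the router
raw_counts = torch.tensor([5, 19, 32, 0])
print(pad_token_group_sizes(raw_counts))  # tensor([16, 32, 32,  0])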
