
Commit c6918d1

[moe training] update tests for torchtitan moe refactor
Parent: 1526dfe

4 files changed, +24 -29 lines

benchmarks/prototype/moe_training/benchmark_moe_layer.py

Lines changed: 7 additions & 9 deletions
@@ -33,13 +33,12 @@
 from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
 from torchao.quantization.quant_api import quantize_
 
-# this test requires torchtitan
+# this benchmark requires torchtitan
 try:
-    from torchtitan.experiments.llama4.infra.expert_parallel import (
+    from torchtitan.distributed.expert_parallel import (
         set_token_group_alignment_size_m,
     )
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -54,16 +53,15 @@ def bench_moe_float8_training_fsdp(enable_profile=False):
 
     # define model args
     target_fqns = ["experts"]
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=16,
-        dim=5120,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 5120, 4 * 5120
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
@@ -90,7 +88,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     fully_shard(ref_model)
 
     # inputs (llama4 shapes)
-    batch, seq, dim = 1, 8192, 5120
+    batch, seq = 1, 8192
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
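
For context, the benchmark's migration boils down to the constructor change visible above: dimensions are no longer fields on the args object but are passed directly to MoE. A minimal sketch of that usage, assuming a torchtitan build that contains the torchtitan.models.moe refactor and a CUDA device; the MoEArgs, MoE(model_args, dim, hidden_dim), and init_weights calls are taken from the diff, everything else is illustrative:

# Sketch only: mirrors the benchmark's new construction pattern (torchtitan post-refactor API).
import torch
from torchtitan.models.moe import MoE, MoEArgs

model_args = MoEArgs(num_experts=16)   # dims are no longer fields on the args object
dim, hidden_dim = 5120, 4 * 5120       # passed explicitly to the constructor instead
device = torch.device("cuda")

model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
model.init_weights(0.02, device)

# llama4-like input shape used by the benchmark
x = torch.randn(1, 8192, dim, dtype=torch.bfloat16, device=device, requires_grad=True)
out = model(x)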

test/prototype/moe_training/test_training.py

Lines changed: 11 additions & 17 deletions
@@ -22,11 +22,10 @@
 
 # this test requires torchtitan
 try:
-    from torchtitan.experiments.llama4.infra.expert_parallel import (
+    from torchtitan.distributed.expert_parallel import (
         set_token_group_alignment_size_m,
     )
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -47,16 +46,15 @@ def test_moe_float8_training(target_fqns: list[str], compile: bool):
     # has the contraction dim be divisible by 16. 16 byte alignment is required
     # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
     set_token_group_alignment_size_m(16)
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 5120, 4 * 5120
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
@@ -75,22 +73,21 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
+    config = MoETrainingConfig()
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # validate that only the experts were converted
     _validate_model_conversion(
         model,
         target_fqns=target_fqns,
     )
-
     if compile:
         # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
         model = torch.compile(model, fullgraph=False)
         ref_model = torch.compile(ref_model, fullgraph=False)
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -145,18 +142,15 @@ def test_moe_mxfp8_training(target_fqns: list[str]):
     # Token groups must be divisible by 32 for mxfp8
     set_token_group_alignment_size_m(block_size)
 
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
-        multiple_of=block_size,
-        ffn_dim_multiplier=1.0,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 256, 4 * 256
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
@@ -185,7 +179,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     )
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
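
The test changes above also drop the explicit scaling_type argument, relying on MoETrainingConfig()'s default, and keep the FQN-based filter so only the expert modules are converted. A rough sketch of that conversion flow, under the assumption that the default config behaves like the previously explicit fp8 rowwise setting; the quantize_/MoETrainingConfig calls and the filter signature come from the diff, while the filter body here is illustrative:

# Sketch only: convert just the expert modules of an MoE model for low-precision training.
import torch.nn as nn
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

target_fqns = ["experts"]

def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # Convert a module only if its fully qualified name matches a target FQN.
    return any(fqn in cur_fqn for fqn in target_fqns)

def convert_for_moe_training(model: nn.Module) -> nn.Module:
    config = MoETrainingConfig()  # scaling type left at its default, as in the updated test
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
    return model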

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 2 additions & 2 deletions
@@ -48,15 +48,15 @@ def _scaled_grouped_mm(
     """
     # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
     if scaling_type == MoEScalingType.FP8_ROWWISE:
-        logger.info("Using fp8 rowwise scaled_grouped_mm")
+        print("Using fp8 rowwise scaled_grouped_mm")
         return _Float8GroupedMM.apply(
             A,
             B_t,
             offs,
             out_dtype,
         )
     elif scaling_type == MoEScalingType.MXFP8:
-        logger.info("Using mxfp8 scaled_grouped_mm")
+        print("Using mxfp8 scaled_grouped_mm")
         block_size = 32  # TODO: should we make this configurable? plumb it through in a config somehow?
         return _MXFP8GroupedMM.apply(
             A,

torchao/prototype/moe_training/tensor.py

Lines changed: 4 additions & 1 deletion
@@ -97,9 +97,12 @@ def __torch_function__(cls, func, types, args, kwargs={}):
         A_is_2d = A.dim() == 2
         B_is_3d = B.dim() == 3
         has_offs = kwargs.get(cls.offs_arg_name) is not None
+        other_args = args[2:]
         if A_is_2d and B_is_3d and has_offs:
             return _scaled_grouped_mm(
-                *args,
+                A,
+                B,
+                *other_args,
                 scaling_type=scaling_type,
                 **kwargs,
            )
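
The tensor.py change forwards the already-extracted A and B plus the remaining positionals instead of re-expanding *args. As a self-contained toy illustration of that __torch_function__ pattern (a made-up subclass around torch.mm, not the torchao subclass in tensor.py):

# Toy illustration only: extract the leading operands from args, inspect them,
# then forward them explicitly along with args[2:] rather than re-expanding *args.
import torch

class InspectingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.mm and len(args) >= 2:
            A, B = args[0], args[1]
            other_args = args[2:]
            assert A.dim() == 2 and B.dim() == 2  # shape check on the named operands
            return super().__torch_function__(func, types, (A, B, *other_args), kwargs)
        return super().__torch_function__(func, types, args, kwargs)

A = torch.randn(4, 3).as_subclass(InspectingTensor)
B = torch.randn(3, 5)
out = torch.mm(A, B)  # routed through __torch_function__ above
print(type(out).__name__, out.shape)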
