 # this test requires torchtitan
 try:
-    from torchtitan.experiments.llama4.infra.expert_parallel import (
+    from torchtitan.distributed.expert_parallel import (
         set_token_group_alignment_size_m,
     )
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -47,16 +46,15 @@ def test_moe_float8_training(target_fqns: list[str], compile: bool):
     # has the contraction dim be divisible by 16. 16 byte alignment is required
     # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
     set_token_group_alignment_size_m(16)
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
     )
     init_std = 0.02
     device = torch.device("cuda")

     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 256, 4 * 256
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)

@@ -75,22 +73,21 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False

     # quantize test model
-    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
+    config = MoETrainingConfig()
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)

     # validate that only the experts were converted
     _validate_model_conversion(
         model,
         target_fqns=target_fqns,
     )
-
     if compile:
         # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
         model = torch.compile(model, fullgraph=False)
         ref_model = torch.compile(ref_model, fullgraph=False)

     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -145,18 +142,15 @@ def test_moe_mxfp8_training(target_fqns: list[str]):
     # Token groups must be divisible by 32 for mxfp8
     set_token_group_alignment_size_m(block_size)

-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
-        multiple_of=block_size,
-        ffn_dim_multiplier=1.0,
     )
     init_std = 0.02
     device = torch.device("cuda")

     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 256, 4 * 256
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)

@@ -185,7 +179,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     )

     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
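
For context, a minimal sketch of how the updated torchtitan API is exercised after this change. It only uses the imports and calls that appear in the diff above; the forward pass at the end and the exact MoE output shape are assumptions for illustration, and it requires a torchtitan build that provides `torchtitan.models.moe.MoEArgs` and the `MoE(args, dim, hidden_dim)` constructor.

```python
import torch

# New import locations used by this test (per the diff above).
from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
from torchtitan.models.moe import MoE, MoEArgs

# fp8 rowwise requires token groups aligned to 16 elements (32 for mxfp8).
set_token_group_alignment_size_m(16)

# Model dims are now passed to MoE directly instead of via TransformerModelArgs.
dim, hidden_dim = 256, 4 * 256
model_args = MoEArgs(num_experts=8)

model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
model.init_weights(0.02, torch.device("cuda"))

# Assumed forward call with a bf16 activation shaped like the test input.
x = torch.randn(8, 2048, dim, dtype=torch.bfloat16, device="cuda")
out = model(x)
```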