Commit 478c5f2

[moe training] fix scaling type bug; refactor distributed tests (#2749)
1 parent c1223e1 commit 478c5f2

4 files changed: +79, -49 lines


test/prototype/moe_training/test_fsdp.py

Lines changed: 21 additions & 12 deletions
@@ -35,8 +35,10 @@
 
 # this test requires torchtitan
 try:
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.distributed.expert_parallel import (
+        set_token_group_alignment_size_m,
+    )
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -49,18 +51,20 @@ def test_moe_float8_training_fsdp():
     # setup distributed for fsdp
     setup_distributed()
 
+    # token group aligment size must be 16 for fp8
+    set_token_group_alignment_size_m(16)
+
     # define model args
     target_fqns = ["experts"]
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
     )
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    dim, hidden_dim = 5120, 4 * 5120
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
@@ -93,7 +97,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     fully_shard(ref_model)
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -105,7 +109,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate output
     out_sqnr = compute_error(out, ref_out)
-    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+    min_out_sqnr = 29.0
+    assert out_sqnr.item() >= min_out_sqnr, (
+        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
+    )
 
     # compute loss
     labels = torch.ones_like(ref_out)
@@ -118,15 +125,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate input gradient
    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
-    assert input_grad_sqnr.item() >= 30.0, (
-        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
+    min_input_grad_sqnr = 29.0
+    assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
+        f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
     )
 
     # validate param gradients
+    min_param_grad_sqnr = 23.0
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
-        assert param_grad_sqnr.item() >= 25.0, (
-            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
+            f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
         )
 
     dist.destroy_process_group()
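The same refactor repeats across all three distributed tests: the llama4 experiment imports are replaced by torchtitan.distributed.expert_parallel and torchtitan.models.moe, the token group alignment size is set to 16 before building the model, and dim/hidden_dim move out of the args object into the MoE constructor. A minimal sketch of the new setup, assuming torchtitan is installed and a CUDA device is available, using only the names shown in the diff above:

# Minimal sketch of the refactored test setup (assumes torchtitan is installed
# and CUDA is available); names mirror the imports in the diff above.
import torch
from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m
from torchtitan.models.moe import MoE, MoEArgs

# per the test comment: token group alignment size must be 16 for fp8
set_token_group_alignment_size_m(16)

# MoEArgs no longer carries dim; dim and hidden_dim are passed to MoE directly
model_args = MoEArgs(num_experts=8)
dim, hidden_dim = 5120, 4 * 5120
ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
ref_model.init_weights(0.02, torch.device("cuda"))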

test/prototype/moe_training/test_fsdp_tp.py

Lines changed: 20 additions & 14 deletions
@@ -49,14 +49,14 @@
 
 # this test requires torchtitan
 try:
-    from torchtitan.experiments.llama4.infra.expert_parallel import (
+    from torchtitan.distributed.expert_parallel import (
         ExpertParallel,
         ExpertTensorParallel,
         NoParallel,
         TensorParallel,
+        set_token_group_alignment_size_m,
     )
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -74,21 +74,22 @@
 def test_moe_float8_training_fsdp_tp(target_fqns: list[str]):
     assert torch.cuda.is_available()
 
+    # token group aligment size must be 16 for fp8
+    set_token_group_alignment_size_m(16)
+
     # setup distributed for tp
     mesh = setup_distributed()
 
     # define model args
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
-        vocab_size=1024,
     )
+    dim, hidden_dim = 5120, 4 * 5120
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(1)
     ref_model.init_weights(init_std, device)
 
@@ -146,7 +147,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     )
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -158,7 +159,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate output
     out_sqnr = compute_error(out, ref_out)
-    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+    min_out_sqnr = 30.0
+    assert out_sqnr.item() >= min_out_sqnr, (
+        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
+    )
 
     # compute loss
     labels = torch.ones_like(ref_out)
@@ -171,15 +175,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate input gradient
     input_grad_sqnr = compute_error(x.grad, ref_x.grad)
-    assert input_grad_sqnr.item() >= 28.0, (
-        f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}."
+    min_input_grad_sqnr = 28.0
+    assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
+        f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
     )
 
     # validate param gradients
+    min_param_grad_sqnr = 23.0
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
-        assert param_grad_sqnr.item() >= 25.0, (
-            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
+            f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
        )
 
     dist.destroy_process_group()
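All three tests gate correctness on SQNR between the float8 model and the bf16 reference, with the thresholds now pulled into named constants (29-30 dB for outputs, 28-29 dB for input gradients, 23 dB for parameter gradients). As a hedged illustration of what those numbers measure, a stand-in for the compute_error helper (not necessarily torchao's exact implementation) is the usual SQNR-in-decibels formula:

import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: higher means `test` is closer
    # to `ref`. Stand-in for the compute_error helper used by the tests.
    return 20.0 * torch.log10(ref.norm() / (ref - test).norm())

# usage mirroring the tests' assertion pattern
ref = torch.randn(1024)
test = ref + 1e-2 * torch.randn(1024)  # small perturbation, roughly 40 dB SQNR
min_sqnr = 23.0
assert sqnr_db(ref, test).item() >= min_sqnr, f"SQNR must be >= {min_sqnr}"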

test/prototype/moe_training/test_tp.py

Lines changed: 23 additions & 17 deletions
@@ -49,14 +49,14 @@
 
 # this test requires torchtitan
 try:
-    from torchtitan.experiments.llama4.infra.expert_parallel import (
+    from torchtitan.distributed.expert_parallel import (
         ExpertParallel,
         ExpertTensorParallel,
         NoParallel,
         TensorParallel,
+        set_token_group_alignment_size_m,
     )
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
 except ImportError:
     pytest.skip(
         "torchtitan not installed, skipping MoE tests.", allow_module_level=True
@@ -74,21 +74,22 @@
 def test_moe_float8_training_tp(target_fqns: list[str]):
     assert torch.cuda.is_available()
 
+    # token group aligment size must be 16 for fp8
+    set_token_group_alignment_size_m(16)
+
     # setup distributed for tp
     mesh = setup_distributed()
 
     # define model args
-    model_args = TransformerModelArgs(
-        moe_enabled=True,
+    model_args = MoEArgs(
         num_experts=8,
-        dim=256,
-        vocab_size=1024,
     )
+    dim, hidden_dim = 5120, 4 * 5120
     init_std = 0.02
     device = torch.device("cuda")
 
     # reference bf16 MoE
-    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(1)
     ref_model.init_weights(init_std, device)
 
@@ -141,7 +142,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     )
 
     # inputs
-    batch, seq, dim = 8, 2048, 256
+    batch, seq = 8, 2048
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -153,7 +154,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate output
     out_sqnr = compute_error(out, ref_out)
-    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+    min_out_sqnr = 29.0
+    assert out_sqnr.item() >= min_out_sqnr, (
+        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
+    )
 
     # compute loss
     labels = torch.ones_like(ref_out)
@@ -166,15 +170,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate input gradient
     input_grad_sqnr = compute_error(x.grad, ref_x.grad)
-    assert input_grad_sqnr.item() >= 28.0, (
-        f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}."
+    min_input_grad_sqnr = 28.0
+    assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
+        f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
     )
 
     # validate param gradients
+    min_param_grad_sqnr = 23.0
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
-        assert param_grad_sqnr.item() >= 25.0, (
-            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
+            f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
        )
 
     dist.destroy_process_group()
@@ -203,17 +209,17 @@ def apply_moe_ep_tp(
     moe_layer_plan = {
         # input / output sharding on the seqlen dim
         # all-gather for input, reduce-scatter for output
-        "moe": PrepareModuleInputOutput(
+        "": PrepareModuleInputOutput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
             use_local_input=True,
             output_layouts=(Partial(),),
             desired_output_layouts=(Shard(1),),
         ),
         # replicate computation for the router
-        "moe.router.gate": NoParallel(),
+        "router.gate": NoParallel(),
         # input Replicate, output Partial
-        "moe.shared_expert": TensorParallel(),
+        "shared_expert": TensorParallel(),
     }
     parallelize_module(
         module=model,
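The moe_layer_plan key change in apply_moe_ep_tp reflects that the plan is now applied to the MoE module itself rather than to a parent block, so the keys become FQNs relative to that module: "" addresses the module's own inputs/outputs, and "router.gate" / "shared_expert" address its children. A rough sketch of the idea, with the caveat that the import paths for PrepareModuleInputOutput, parallelize_module, and the placement types are assumptions here (the test imports them elsewhere in the file):

# Sketch only: plan keys are relative to the module passed to parallelize_module.
# NoParallel / TensorParallel come from torchtitan.distributed.expert_parallel
# (as imported in the diff above); PrepareModuleInputOutput, parallelize_module,
# Shard, Replicate, and Partial are assumed to come from torch's DTensor TP API.
from torch.distributed.tensor import Partial, Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInputOutput, parallelize_module
from torchtitan.distributed.expert_parallel import NoParallel, TensorParallel


def apply_tp_to_moe(moe_module, tp_mesh):
    moe_layer_plan = {
        # "" targets moe_module itself: all-gather its input, reduce-scatter its output
        "": PrepareModuleInputOutput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
            use_local_input=True,
            output_layouts=(Partial(),),
            desired_output_layouts=(Shard(1),),
        ),
        # children are addressed relative to moe_module, not via a parent "moe." prefix
        "router.gate": NoParallel(),
        "shared_expert": TensorParallel(),
    }
    parallelize_module(module=moe_module, device_mesh=tp_mesh, parallelize_plan=moe_layer_plan)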

torchao/prototype/moe_training/tensor.py

Lines changed: 15 additions & 6 deletions
@@ -114,16 +114,25 @@ def __torch_function__(cls, func, types, args, kwargs={}):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
-        # detach is special case
-        scaling_type = args[0].scaling_type
-        if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data, scaling_type)
+        # unwrap args/kwargs and extract scaling_type
+        scaling_type = None
+
+        def unwrap(t):
+            nonlocal scaling_type
+            if scaling_type is None:
+                scaling_type = t.scaling_type
+            else:
+                assert t.scaling_type == scaling_type
+            return t._data
 
-        # unwrap args/kwargs
-        unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
+        assert scaling_type is not None
+
+        # detach is special case
+        if func == torch.ops.aten.detach.default:
+            return ScaledGroupedMMTensor(args[0], scaling_type)
 
         # perform op
         out = func(*args, **kwargs)
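The tensor.py hunk is the scaling type bug fix named in the commit title. The old __torch_dispatch__ read args[0].scaling_type before unwrapping, which assumes the first positional argument is a ScaledGroupedMMTensor; for ops where the wrapper appears in some other position, that attribute access breaks. The new code captures scaling_type from whichever wrapped tensors the pytree unwrap actually visits, asserts they agree, and only then handles the detach special case on the already-unwrapped data. A self-contained toy illustration of the pattern (a hypothetical TaggedTensor wrapper, not torchao's class):

# Toy illustration of the fix: capture an attribute from the wrapped tensors
# while unwrapping them with pytree, instead of assuming args[0] is the wrapper.
import torch
import torch.utils._pytree as pytree


class TaggedTensor:
    # hypothetical wrapper carrying a payload tensor plus a tag attribute
    def __init__(self, data: torch.Tensor, tag: str):
        self._data = data
        self.tag = tag


def dispatch(func, args, kwargs=None):
    tag = None

    def unwrap(t: TaggedTensor) -> torch.Tensor:
        # capture the tag from every wrapped argument, wherever it appears
        nonlocal tag
        if tag is None:
            tag = t.tag
        else:
            assert t.tag == tag, "all wrapped args must share one tag"
        return t._data

    # tree_map_only touches only TaggedTensor leaves in args/kwargs
    args, kwargs = pytree.tree_map_only(TaggedTensor, unwrap, (args, kwargs or {}))
    assert tag is not None, "expected at least one TaggedTensor argument"
    out = func(*args, **kwargs)
    return TaggedTensor(out, tag)


# works even when the wrapper is not the first positional argument,
# which is the case the old args[0].scaling_type lookup mishandled
plain = torch.ones(2, 2)
wrapped = TaggedTensor(torch.ones(2, 2), tag="rowwise")
result = dispatch(torch.add, (plain, wrapped))
assert result.tag == "rowwise"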
