Skip to content

Commit 5c04917

Browse files
ilml and claude committed
fix: pass layer_wise_distributed_optimizer via config instead of kwarg
The test was passing layer_wise_distributed_optimizer as a keyword argument to get_megatron_muon_optimizer(), but that function does not accept it. Set the use_layer_wise_distributed_optimizer attribute on the OptimizerConfig object instead.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6521ee2 commit 5c04917

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/unit_tests/test_layer_wise_optimizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,11 @@ def create_model_and_optimizer(
124124
pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True)
125125
pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group()
126126

127+
optimizer_config.use_layer_wise_distributed_optimizer = use_layer_wise
127128
optimizer = get_megatron_muon_optimizer(
128129
config=optimizer_config,
129130
model_chunks=[model],
130131
use_gloo_process_groups=True,
131-
layer_wise_distributed_optimizer=use_layer_wise,
132132
pg_collection=pg_collection,
133133
)
134134
return model, optimizer, pg_collection
@@ -197,11 +197,11 @@ def create_model_and_optimizer_with_overlap_param_gather(
197197
pg_collection.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True)
198198
pg_collection.expt_dp = parallel_state.get_expert_data_parallel_group()
199199

200+
optimizer_config.use_layer_wise_distributed_optimizer = True
200201
optimizer = get_megatron_muon_optimizer(
201202
config=optimizer_config,
202203
model_chunks=[model],
203204
use_gloo_process_groups=True,
204-
layer_wise_distributed_optimizer=True,
205205
pg_collection=pg_collection,
206206
)
207207
return model, optimizer, pg_collection
@@ -399,11 +399,11 @@ def test_bf16_error(self):
399399
use_distributed_optimizer=False,
400400
muon_tp_mode="duplicated",
401401
)
402+
optimizer_config.use_layer_wise_distributed_optimizer = False
402403
muon_optimizer = get_megatron_muon_optimizer(
403404
config=optimizer_config,
404405
model_chunks=[model],
405406
use_gloo_process_groups=True,
406-
layer_wise_distributed_optimizer=False,
407407
pg_collection=pg_collection,
408408
)
409409

0 commit comments

Comments
 (0)