import torch
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig as MCoreGPTDatasetConfig
from megatron.core.distributed import DistributedDataParallelConfig as MCoreDistributedDataParallelConfig
-from megatron.core.optimizer import (
-    OptimizerConfig as MCoreOptimizerConfig,
-)
+from megatron.core.optimizer import OptimizerConfig as MCoreOptimizerConfig
from megatron.core.optimizer import (
    ParamGroupOverride,
    ParamKey,
@@ -286,8 +284,8 @@ def build_config_overrides(self, context: OptimizerConfigOverrideProviderContext
        For Qwen3-Next models, weight decay is applied to q_layernorm and k_layernorm.

        Args:
-            scheduler_config: Scheduler configuration containing weight decay settings
-            model: The model or list of model chunks to collect parameter names from
+            context: OptimizerConfigOverrideProviderContext bundling the scheduler
+                configuration, optimizer configuration, and model.

        Returns:
            Dictionary mapping ParamKey to ParamGroupOverride for the optimizer
@@ -302,7 +300,7 @@ def build_config_overrides(self, context: OptimizerConfigOverrideProviderContext
        # NOTE: this can be simplified once https://github.com/NVIDIA/Megatron-LM/pull/2753
        # is merged into dev. Then we can re-use megatron's apply_wd_to_qk_layernorm option
        # and call megatron.core.optimizer.get_standard_config_overrides(optimizer_config)
-        # directly for standard settings.
+        # directly for standard settings, replacing the custom logic below for qwen3-next.
        no_wd_names: list[str] = []
        is_qwen3_next = scheduler_config.no_weight_decay_cond_type == "qwen3_next"
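For reference, a minimal, self-contained sketch of the selection logic this hunk implements: biases and norm parameters are excluded from weight decay, except that for Qwen3-Next the q_layernorm/k_layernorm parameters keep weight decay, matching the docstring above. The collect_no_wd_names helper and the substring name patterns are illustrative assumptions, not the PR's exact code.

import torch.nn as nn


def collect_no_wd_names(model: nn.Module, is_qwen3_next: bool) -> list[str]:
    """Sketch: gather parameter names to exclude from weight decay.

    Assumes norm parameters can be identified by "norm" in their name and
    that Qwen3-Next's QK layernorms are named q_layernorm/k_layernorm;
    both are naming assumptions made for illustration.
    """
    no_wd_names: list[str] = []
    for name, _param in model.named_parameters():
        is_bias = name.endswith(".bias")
        is_norm = "norm" in name
        # Qwen3-Next keeps weight decay on the QK layernorms specifically.
        keeps_wd = is_qwen3_next and ("q_layernorm" in name or "k_layernorm" in name)
        if is_bias or (is_norm and not keeps_wd):
            no_wd_names.append(name)
    return no_wd_names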