
Commit b4399e8

minor fix for megatron compatibility (#1149)
1 parent 54762b7 commit b4399e8

File tree: 4 files changed (+18, −24 lines)


slime/backends/megatron_utils/loss.py

Lines changed: 1 addition & 1 deletion
@@ -771,7 +771,7 @@ def loss_function(
 
     return (
         loss,
-        num_tokens if args.calculate_per_token_loss else 1,
+        torch.tensor(num_tokens if args.calculate_per_token_loss else 1, device=logits.device),
         {
             "keys": list(log.keys()),
             "values": torch.tensor(

slime/backends/megatron_utils/model.py

Lines changed: 2 additions & 8 deletions
@@ -84,9 +84,6 @@ def get_optimizer_param_scheduler(args: Namespace, optimizer: MegatronOptimizer)
 def setup_model_and_optimizer(
     args: Namespace,
     role: str = "actor",
-    no_wd_decay_cond: Callable[..., bool] | None = None,
-    scale_lr_cond: Callable[..., bool] | None = None,
-    lr_mult: float = 1.0,
 ) -> tuple[list[DDP], MegatronOptimizer, OptimizerParamScheduler]:
     """Build model(s), wrap with DDP, and construct optimizer and scheduler.
 
@@ -119,11 +116,8 @@ def setup_model_and_optimizer(
     config.timers = None
 
     optimizer = get_megatron_optimizer(
-        config,
-        model,
-        no_wd_decay_cond,
-        scale_lr_cond,
-        lr_mult,
+        config=config,
+        model_chunks=model,
         use_gloo_process_groups=args.enable_gloo_process_groups,
     )
     opt_param_scheduler = get_optimizer_param_scheduler(args, optimizer)
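
Dropping the unused no_wd_decay_cond / scale_lr_cond / lr_mult pass-throughs and binding the remaining arguments by keyword makes the optimizer construction robust to signature changes between Megatron releases. A rough sketch of the resulting call site, assuming get_megatron_optimizer from megatron.core.optimizer accepts the keywords shown in the diff:

from megatron.core.optimizer import get_megatron_optimizer

def build_optimizer(config, model, args):
    # Keyword arguments bind the config and the model chunks by name, so the call
    # does not depend on the positional order of optional parameters such as
    # no_wd_decay_cond, scale_lr_cond, or lr_mult, which vary across versions.
    return get_megatron_optimizer(
        config=config,
        model_chunks=model,
        use_gloo_process_groups=args.enable_gloo_process_groups,
    )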

slime/backends/megatron_utils/model_provider.py

Lines changed: 10 additions & 10 deletions
@@ -94,19 +94,19 @@ def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage
     # Define the decoder layer spec
     if use_te:
         transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
-            args.num_experts,
-            args.moe_grouped_gemm,
-            args.qk_layernorm,
-            args.multi_latent_attention,
-            args.moe_use_legacy_grouped_gemm,
+            num_experts=args.num_experts,
+            moe_grouped_gemm=args.moe_grouped_gemm,
+            qk_layernorm=args.qk_layernorm,
+            multi_latent_attention=args.multi_latent_attention,
+            moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
         )
     else:
         transformer_layer_spec = get_gpt_layer_local_spec(
-            args.num_experts,
-            args.moe_grouped_gemm,
-            args.qk_layernorm,
-            args.multi_latent_attention,
-            args.moe_use_legacy_grouped_gemm,
+            num_experts=args.num_experts,
+            moe_grouped_gemm=args.moe_grouped_gemm,
+            qk_layernorm=args.qk_layernorm,
+            multi_latent_attention=args.multi_latent_attention,
+            moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
         )
 
     build_model_context = nullcontext
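
Both layer-spec helpers are now called with keyword arguments for the same reason: if a Megatron release inserts or reorders a parameter, a positional call silently binds values to the wrong argument. A small illustration with a hypothetical helper (not Megatron code):

# Hypothetical helper whose signature gains a new parameter in a later release.
def get_layer_spec_v2(num_experts=None, new_option=False, moe_grouped_gemm=False, qk_layernorm=False):
    return {"num_experts": num_experts, "moe_grouped_gemm": moe_grouped_gemm, "qk_layernorm": qk_layernorm}

# A positional call written against the old signature now binds the value meant
# for moe_grouped_gemm to new_option, the value meant for qk_layernorm to
# moe_grouped_gemm, and leaves qk_layernorm at its default:
broken = get_layer_spec_v2(8, True, True)
# -> {'num_experts': 8, 'moe_grouped_gemm': True, 'qk_layernorm': False}

# The keyword call stays correct regardless of parameter order:
ok = get_layer_spec_v2(num_experts=8, moe_grouped_gemm=True, qk_layernorm=True)
# -> {'num_experts': 8, 'moe_grouped_gemm': True, 'qk_layernorm': True}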

slime_plugins/models/glm4.py

Lines changed: 5 additions & 5 deletions
@@ -3,11 +3,11 @@
 
 def get_glm_spec(args, config, vp_stage):
     transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
-        args.num_experts,
-        args.moe_grouped_gemm,
-        args.qk_layernorm,
-        args.multi_latent_attention,
-        args.moe_use_legacy_grouped_gemm,
+        num_experts=args.num_experts,
+        moe_grouped_gemm=args.moe_grouped_gemm,
+        qk_layernorm=args.qk_layernorm,
+        multi_latent_attention=args.multi_latent_attention,
+        moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
         post_self_attn_layernorm=args.post_self_attn_layernorm,
         post_mlp_layernorm=args.post_mlp_layernorm,
     )
