### Summary

Trying to use GPT-OSS with on-policy distillation (per #1679: "should work by changing model names") fails during model initialization when loading the base checkpoint via `torch.distributed.checkpoint` (DCP). After working around a Transformers attention-implementation error by forcing `attn_implementation="eager"`, the run still crashes with a missing key in the checkpoint `state_dict`:

```
RuntimeError: Missing key in checkpoint state_dict: model.layers.0.mlp.experts.down_proj.
```
### Expected behavior

Running `examples/run_distillation_math.py` with a GPT-OSS teacher/student should initialize successfully and start distillation training (or at least reach the first training step) after swapping model names as described in #1679.
### Actual behavior

The Ray teacher actors die during initialization (`DTensorPolicyWorkerV2.__init__`) while loading the base model checkpoint. All ranks report a `CheckpointException`, with the root cause being a missing key in the checkpoint.
### Exact command

```shell
uv run python examples/run_distillation_math.py \
  policy.model_name="openai/gpt-oss-20b" \
  teacher.model_name="openai/gpt-oss-20b" \
  cluster.gpus_per_node=8
```
### Changes made to try GPT-OSS

1) Config change

In `distillation_math.yaml`:

```yaml
policy:
  dtensor_cfg:
    automodel_kwargs:
      force_hf: true
```
2) Workaround for Transformers attention error
Original error:

```
GptOssForCausalLM does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet...
... load your model with attn_implementation="eager"
```
I applied the suggested workaround by forcing eager attention in `nemo_rl/models/automodel/setup.py`:

- set `attn_impl = "eager"`
- and passed `attn_implementation="eager"` into the config load:

```python
model_config = AutoConfig.from_pretrained(
    model_name,
    torch_dtype=torch.float32,  # Always load in float32 for master weights
    trust_remote_code=True,
    attn_implementation="eager",
    **hf_config_overrides,
)
```
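One side note on this patch (my own observation, not something hit in this run): if `hf_config_overrides` ever carries its own `attn_implementation` entry, passing the explicit keyword alongside `**hf_config_overrides` raises a duplicate-keyword `TypeError` at the call site. A minimal sketch of folding the override into the kwargs dict first (`build_config_kwargs` is a hypothetical helper, not part of NeMo-RL):

```python
def build_config_kwargs(hf_config_overrides, attn_impl="eager"):
    """Merge a default attention implementation into HF config kwargs.

    An attn_implementation already present in the overrides wins; the
    default is only filled in when the overrides omit it, so the call
    site never passes the keyword twice.
    """
    kwargs = dict(hf_config_overrides)  # copy so the caller's dict is untouched
    kwargs.setdefault("attn_implementation", attn_impl)
    return kwargs

# Example: empty overrides pick up the eager default.
print(build_config_kwargs({}))  # {'attn_implementation': 'eager'}
```

The merged dict can then be splatted into `AutoConfig.from_pretrained(model_name, **kwargs)` without risking the collision.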
---
### Error log (root cause)

This is the first actionable failure I see (repeated across ranks):

```
RuntimeError: Missing key in checkpoint state_dict: model.layers.0.mlp.experts.down_proj.
```
Full trace excerpt (trimmed to the relevant portion):

```
ray::teacher-0-7:DTensorPolicyWorkerV2.__init__()
  File ".../dtensor_policy_worker_v2.py", line 228, in __init__
    model_and_optimizer_state = setup_model_and_optimizer(...)
  File ".../nemo_rl/models/automodel/setup.py", line 484, in setup_model_and_optimizer
    checkpoint_manager.load_base_model(...)
  File ".../nemo_rl/utils/automodel_checkpoint.py", line 237, in load_base_model
    self.checkpointer.load_base_model(...)
  File ".../nemo_automodel/components/checkpoint/checkpointing.py", line 369, in load_base_model
    self.load_model(...)
  File ".../checkpointing.py", line 474, in _do_load
    dcp.load(state_dict, checkpoint_id=path, storage_reader=storage_reader)
  ...
  File ".../torch/distributed/checkpoint/default_planner.py", line 471, in create_default_local_load_plan
    raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
RuntimeError: Missing key in checkpoint state_dict: model.layers.0.mlp.experts.down_proj.
```
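For anyone triaging a similar DCP mismatch: the error means the in-memory model expects an FQN the checkpoint files do not contain, so a quick first step is to diff the two key sets (my guess, unverified, is that the GPT-OSS HF checkpoint stores expert weights under different, possibly quantization-split names rather than a plain `down_proj`). A minimal sketch of such a diff; the checkpoint key names below are purely illustrative, not read from the real files:

```python
def diff_state_dict_keys(model_keys, ckpt_keys):
    """Compare the FQNs a model expects against those a checkpoint provides.

    Returns (missing_from_ckpt, unexpected_in_ckpt) as sorted lists.
    """
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    return sorted(model_set - ckpt_set), sorted(ckpt_set - model_set)


# Toy illustration with made-up checkpoint names:
model_keys = [
    "model.layers.0.mlp.experts.down_proj",
    "model.layers.0.mlp.experts.gate_up_proj",
]
ckpt_keys = [
    "model.layers.0.mlp.experts.down_proj_blocks",
    "model.layers.0.mlp.experts.down_proj_scales",
]
missing, unexpected = diff_state_dict_keys(model_keys, ckpt_keys)
print(missing)  # the FQNs DCP would report as missing
```

In practice `model_keys` would come from `model.state_dict().keys()` and `ckpt_keys` from listing the tensors in the checkpoint shards; a mapping/renaming step between the two layouts is the kind of fix this diff would motivate.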