
GPT-OSS on-policy distillation fails during checkpoint load (Missing key: model.layers.0.mlp.experts.down_proj) #1810

@uygarmv

Description


Summary

Using GPT-OSS for on-policy distillation (per #1679: “should work by changing model names”) fails during model initialization when the base checkpoint is loaded via torch.distributed.checkpoint (DCP). After working around a Transformers attention-implementation error by forcing attn_implementation="eager", the run still crashes on a key missing from the checkpoint state_dict:

RuntimeError: Missing key in checkpoint state_dict: model.layers.0.mlp.experts.down_proj.


Expected behavior

Running examples/run_distillation_math.py with GPT-OSS teacher/student should initialize successfully and start distillation training (or at least reach the first training step) after swapping model names as described in #1679.


Actual behavior

The Ray teacher actors die during initialization (DTensorPolicyWorkerV2.__init__) while loading the base model checkpoint. All ranks report a CheckpointException, with the root cause being a missing key in the checkpoint.


Exact command

uv run python examples/run_distillation_math.py \
  policy.model_name="openai/gpt-oss-20b" \
  teacher.model_name="openai/gpt-oss-20b" \
  cluster.gpus_per_node=8

Changes made to try GPT-OSS

1) Config change

In distillation_math.yaml:

```yaml
policy:
  dtensor_cfg:
    automodel_kwargs:
      force_hf: true
```
2) Workaround for Transformers attention error

Original error:

GptOssForCausalLM does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet...
... load your model with attn_implementation="eager"

I applied the suggested workaround by forcing eager attention in nemo_rl/models/automodel/setup.py:
• set attn_impl = "eager"
• passed attn_implementation="eager" into the AutoConfig load:

```python
model_config = AutoConfig.from_pretrained(
    model_name,
    torch_dtype=torch.float32,  # Always load in float32 for master weights
    trust_remote_code=True,
    attn_implementation="eager",
    **hf_config_overrides,
)
```
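An untested alternative (assumption: entries under dtensor_cfg.automodel_kwargs are forwarded into the HF config/model load, as the force_hf flag above suggests) would be to carry the override in distillation_math.yaml instead of patching setup.py:

```yaml
policy:
  dtensor_cfg:
    automodel_kwargs:
      force_hf: true
      # Assumption: this kwarg reaches AutoConfig/from_pretrained unchanged
      attn_implementation: "eager"
```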

Error log (root cause)

This is the first actionable failure I see (repeated across ranks):

RuntimeError: Missing key in checkpoint state_dict: model.layers.0.mlp.experts.down_proj.

Full trace excerpt (trimmed to the relevant portion):

```
ray::teacher-0-7:DTensorPolicyWorkerV2.__init__()
  File ".../dtensor_policy_worker_v2.py", line 228, in __init__
    model_and_optimizer_state = setup_model_and_optimizer(...)
  File ".../nemo_rl/models/automodel/setup.py", line 484, in setup_model_and_optimizer
    checkpoint_manager.load_base_model(...)
  File ".../nemo_rl/utils/automodel_checkpoint.py", line 237, in load_base_model
    self.checkpointer.load_base_model(...)
  File ".../nemo_automodel/components/checkpoint/checkpointing.py", line 369, in load_base_model
    self.load_model(...)
  File ".../checkpointing.py", line 474, in _do_load
    dcp.load(state_dict, checkpoint_id=path, storage_reader=storage_reader)
  ...
  File ".../torch/distributed/checkpoint/default_planner.py", line 471, in create_default_local_load_plan
    raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
RuntimeError: Missing key in checkpoint state_dict: model.layers.0.mlp.experts.down_proj.
```
