
skip_reference_policy_logprobs_calculation=true crashes training with RuntimeError / AttributeError #1968

@dmvevents

Description

Bug Description

Setting grpo.skip_reference_policy_logprobs_calculation=true with loss_fn.reference_policy_kl_penalty=0 crashes training in NeMo RL v0.5.0. There are three distinct bugs:

Bug 1: use_reference_model() generator crash (sync GRPO)

When skip_reference_policy_logprobs_calculation=true, the reference model is never initialized (init_reference_model is not called, so self.reference_state_dict does not exist on the worker). However, the sync GRPO code path at grpo.py L1684-L1692 correctly guards the call:

if not master_config["grpo"].get("skip_reference_policy_logprobs_calculation"):
    train_data["reference_policy_logprobs"] = (
        policy.get_reference_policy_logprobs(logprob_data, timer=timer)["reference_logprobs"]
    )

But base_policy_worker.py L125-149 implements get_reference_policy_logprobs() using with self.use_reference_model():, and use_reference_model() in megatron_policy_worker.py L524-573 is a @contextmanager generator that accesses self.reference_state_dict (L546). When the reference model was never initialized, this attribute does not exist, causing an AttributeError. Even if reference_state_dict is explicitly set to None, load_state_dict(None) fails, the generator exits before ever reaching its yield, and the @contextmanager machinery surfaces this as:

RuntimeError: generator didn't yield
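This contextlib behavior can be reproduced in isolation: any @contextmanager generator that exits before reaching its yield raises exactly this error (minimal stand-alone sketch, independent of the NeMo RL code):

```python
from contextlib import contextmanager

@contextmanager
def use_reference_model():
    # Simulated failure path: the generator exits before ever yielding,
    # analogous to the uninitialized-reference-model case in the worker.
    return
    yield  # never reached

try:
    with use_reference_model():
        pass
except RuntimeError as e:
    print(e)  # generator didn't yield
```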

Bug 2: Async GRPO path ignores the skip flag entirely

The async GRPO code path at grpo.py L2699-L2704 unconditionally calls get_reference_policy_logprobs() without checking skip_reference_policy_logprobs_calculation:

reference_logprobs = policy.get_reference_policy_logprobs(
    train_data,
    timer=timer,
)["reference_logprobs"]
train_data["reference_policy_logprobs"] = reference_logprobs

This always crashes when the reference model was not initialized.

Bug 3: Missing reference_policy_logprobs in train_data causes shape mismatch

Even when Bug 1 is worked around (by catching the exception), train_data["reference_policy_logprobs"] is never set when the call is skipped. Downstream, policy.train(train_data) expects this key to exist and have the same shape as other logprob tensors (e.g., [batch_size, seq_len]). Its absence or None value causes assertion failures or shape mismatches in the loss computation.
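For context on why substituting zeros is safe here: with loss_fn.reference_policy_kl_penalty=0, the reference term is multiplied by zero and contributes nothing to the loss. A stand-alone sketch (the k3-style KL estimator below is an assumption for illustration, not NeMo RL's actual loss code):

```python
import math

beta = 0.0  # loss_fn.reference_policy_kl_penalty=0

# Per-token logprobs, hypothetical values for a single sequence
policy_lp = [-1.2, -0.7, -2.3]
reference_lp = [0.0, 0.0, 0.0]  # zeros substituted when the call is skipped

# k3 KL estimator commonly used in GRPO-style losses (assumption)
kl = [math.exp(r - p) - (r - p) - 1 for p, r in zip(policy_lp, reference_lp)]
penalty = [beta * k for k in kl]

assert all(x == 0.0 for x in penalty)  # beta=0: reference term vanishes
```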

Steps to Reproduce

python examples/run_grpo.py \
    +grpo.skip_reference_policy_logprobs_calculation=true \
    ++loss_fn.reference_policy_kl_penalty=0 \
    # ... other standard GRPO config

Error Messages

RuntimeError: generator didn't yield

from base_policy_worker.py:143 (with self.use_reference_model():)

or:

AttributeError: 'MegatronPolicyWorker' object has no attribute 'reference_state_dict'

Environment

  • NeMo RL v0.5.0 (nvcr.io/nvidia/nemo-rl:v0.5.0)
  • Tested on both P5.48xlarge (H100) and P5en.48xlarge (H200)
  • PyTorch 2.x, CUDA 12.x

Suggested Fix

1. Guard use_reference_model() in megatron_policy_worker.py

Add an early return when no reference model is initialized:

@contextmanager
def use_reference_model(self):
    if not hasattr(self, "reference_state_dict") or self.reference_state_dict is None:
        yield  # no-op context manager
        return
    # ... existing implementation

2. Guard get_reference_policy_logprobs() in base_policy_worker.py

Return zeros when no reference model exists:

def get_reference_policy_logprobs(self, *, data, micro_batch_size=None):
    if not hasattr(self, "reference_state_dict") or self.reference_state_dict is None:
        # Return zeros matching expected shape when reference model is skipped
        logprobs = self.get_logprobs(data=data, micro_batch_size=micro_batch_size)
        return_data = BatchedDataDict[ReferenceLogprobOutputSpec]()
        return_data["reference_logprobs"] = torch.zeros_like(logprobs["logprobs"]).cpu()
        return return_data
    # ... existing implementation

3. Add skip check in async GRPO path (grpo.py L2699)

Mirror the sync path's guard:

if not master_config["grpo"].get("skip_reference_policy_logprobs_calculation"):
    reference_logprobs = policy.get_reference_policy_logprobs(
        train_data, timer=timer,
    )["reference_logprobs"]
    train_data["reference_policy_logprobs"] = reference_logprobs
else:
    # Set zeros to maintain expected tensor shape for downstream loss computation
    train_data["reference_policy_logprobs"] = torch.zeros_like(fprop_logprobs)

Workaround

We are currently using runtime monkey-patches to work around this (two scripts applied at container startup):

  1. Worker patch: Injects early-return guards into use_reference_model(), use_training_model(), and get_reference_policy_logprobs() when self.reference_state_dict is None
  2. GRPO patch: Wraps get_reference_policy_logprobs() in try/except, and injects zeros_like for reference_policy_logprobs before policy.train() when the value is None or has wrong shape
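A stripped-down version of the worker patch, applied to a stand-in class (real class names, method signatures, and return types are assumptions based on the tracebacks above; adapt to the actual base_policy_worker.py):

```python
from functools import wraps

def guard_reference_logprobs(worker_cls):
    """Inject an early-return guard into the worker class (sketch)."""
    original = worker_cls.get_reference_policy_logprobs

    @wraps(original)
    def patched(self, data, **kwargs):
        if getattr(self, "reference_state_dict", None) is None:
            # Reference model never initialized: substitute zeroed logprobs
            # with the same per-token shape as the policy logprobs.
            logprobs = self.get_logprobs(data, **kwargs)["logprobs"]
            return {"reference_logprobs": [0.0 for _ in logprobs]}
        return original(self, data, **kwargs)

    worker_cls.get_reference_policy_logprobs = patched

class DummyWorker:
    """Stand-in for the real policy worker class."""
    reference_state_dict = None

    def get_logprobs(self, data, **kwargs):
        return {"logprobs": [-1.1, -0.4, -2.0]}

    def get_reference_policy_logprobs(self, data, **kwargs):
        raise AttributeError("no attribute 'reference_state_dict'")  # original crash

guard_reference_logprobs(DummyWorker)
print(DummyWorker().get_reference_policy_logprobs(None))
# {'reference_logprobs': [0.0, 0.0, 0.0]}
```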
