Bug Description
Setting `grpo.skip_reference_policy_logprobs_calculation=true` with `loss_fn.reference_policy_kl_penalty=0` crashes training in NeMo RL v0.5.0. There are three distinct bugs:
Bug 1: use_reference_model() generator crash (sync GRPO)
When `skip_reference_policy_logprobs_calculation=true`, the reference model is never initialized (`init_reference_model` is not called, so `self.reference_state_dict` does not exist on the worker). The sync GRPO code path at `grpo.py` L1684-L1692 correctly guards the call:

```python
if not master_config["grpo"].get("skip_reference_policy_logprobs_calculation"):
    train_data["reference_policy_logprobs"] = (
        policy.get_reference_policy_logprobs(logprob_data, timer=timer)["reference_logprobs"]
    )
```

But `base_policy_worker.py` L125-149 implements `get_reference_policy_logprobs()` using `with self.use_reference_model():`, and `use_reference_model()` in `megatron_policy_worker.py` L524-573 is a `@contextmanager` generator that accesses `self.reference_state_dict` (L546). When the reference model was never initialized, this attribute does not exist, causing an `AttributeError`. Even if `reference_state_dict` is set to `None`, the generator fails before reaching its `yield` because `load_state_dict(None)` raises, and the `@contextmanager` machinery surfaces this as:

```
RuntimeError: generator didn't yield
```
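This failure mode can be reproduced in isolation. The sketch below is a hypothetical stand-in (not the real worker code): a `@contextmanager` generator whose body exits before reaching `yield` makes `__enter__` receive `StopIteration`, which `contextlib` reports as `generator didn't yield`:

```python
from contextlib import contextmanager

@contextmanager
def use_reference_model(reference_state_dict=None):
    # Stand-in for the worker's context manager: when the body exits
    # before reaching `yield`, contextlib raises "generator didn't yield"
    if reference_state_dict is None:
        return  # exits without yielding
    yield reference_state_dict

try:
    with use_reference_model():
        pass
except RuntimeError as e:
    print(e)  # → generator didn't yield

# With a state dict present, the context manager works normally
with use_reference_model(reference_state_dict={"w": 1}):
    pass
```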
Bug 2: Async GRPO path ignores the skip flag entirely
The async GRPO code path at `grpo.py` L2699-L2704 unconditionally calls `get_reference_policy_logprobs()` without checking `skip_reference_policy_logprobs_calculation`:

```python
reference_logprobs = policy.get_reference_policy_logprobs(
    train_data,
    timer=timer,
)["reference_logprobs"]
train_data["reference_policy_logprobs"] = reference_logprobs
```

This always crashes when the reference model was not initialized.
Bug 3: Missing reference_policy_logprobs in train_data causes shape mismatch
Even when Bug 1 is worked around (by catching the exception), `train_data["reference_policy_logprobs"]` is never set when the call is skipped. Downstream, `policy.train(train_data)` expects this key to exist and to have the same shape as the other logprob tensors (e.g., `[batch_size, seq_len]`). Its absence, or a `None` value, causes assertion failures or shape mismatches in the loss computation.
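To illustrate why the key must exist even when the KL penalty is zero, here is a minimal hypothetical loss sketch (`grpo_loss` is illustrative, not NeMo RL's actual loss function): the reference logprobs enter the computation elementwise, so a missing or mis-shaped tensor fails before the zero coefficient can neutralize it.

```python
import torch

def grpo_loss(logprobs, ref_logprobs, kl_penalty=0.0):
    # Even with kl_penalty=0, the reference logprobs are combined
    # elementwise with the policy logprobs, so the tensor must exist
    # and match the [batch_size, seq_len] shape
    kl = logprobs - ref_logprobs
    return (kl_penalty * kl - logprobs).mean()

logprobs = torch.zeros(4, 16)  # [batch_size, seq_len]
print(grpo_loss(logprobs, torch.zeros_like(logprobs)).item())  # works: 0.0

try:
    grpo_loss(logprobs, None)  # missing key -> None
except TypeError as e:
    print("None crashes:", e)

try:
    grpo_loss(logprobs, torch.zeros(4, 8))  # wrong seq_len
except RuntimeError as e:
    print("shape mismatch:", e)
```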
Steps to Reproduce
```bash
python examples/run_grpo.py \
    +grpo.skip_reference_policy_logprobs_calculation=true \
    ++loss_fn.reference_policy_kl_penalty=0 \
    # ... other standard GRPO config
```

Error Messages
```
RuntimeError: generator didn't yield
```

from `base_policy_worker.py:143` (`with self.use_reference_model():`), or:

```
AttributeError: 'MegatronPolicyWorker' object has no attribute 'reference_state_dict'
```
Environment
- NeMo RL v0.5.0 (`nvcr.io/nvidia/nemo-rl:v0.5.0`)
- Tested on both P5.48xlarge (H100) and P5en.48xlarge (H200)
- PyTorch 2.x, CUDA 12.x
Suggested Fix
1. Guard use_reference_model() in megatron_policy_worker.py
Add an early return when no reference model is initialized:
```python
@contextmanager
def use_reference_model(self):
    if not hasattr(self, "reference_state_dict") or self.reference_state_dict is None:
        yield  # no-op context manager
        return
    # ... existing implementation
```

2. Guard `get_reference_policy_logprobs()` in `base_policy_worker.py`
Return zeros when no reference model exists:
```python
def get_reference_policy_logprobs(self, *, data, micro_batch_size=None):
    if not hasattr(self, "reference_state_dict") or self.reference_state_dict is None:
        # Return zeros matching the expected shape when the reference model is skipped
        logprobs = self.get_logprobs(data=data, micro_batch_size=micro_batch_size)
        return_data = BatchedDataDict[ReferenceLogprobOutputSpec]()
        return_data["reference_logprobs"] = torch.zeros_like(logprobs["logprobs"]).cpu()
        return return_data
    # ... existing implementation
```

3. Add skip check in async GRPO path (`grpo.py` L2699)
Mirror the sync path's guard:
```python
if not master_config["grpo"].get("skip_reference_policy_logprobs_calculation"):
    reference_logprobs = policy.get_reference_policy_logprobs(
        train_data, timer=timer,
    )["reference_logprobs"]
    train_data["reference_policy_logprobs"] = reference_logprobs
else:
    # Set zeros to maintain the expected tensor shape for downstream loss computation
    train_data["reference_policy_logprobs"] = torch.zeros_like(fprop_logprobs)
```

Workaround
We are currently using runtime monkey-patches to work around this (two scripts applied at container startup):
- Worker patch: injects early-return guards into `use_reference_model()`, `use_training_model()`, and `get_reference_policy_logprobs()` when `self.reference_state_dict` is `None`
- GRPO patch: wraps `get_reference_policy_logprobs()` in try/except, and injects `zeros_like` for `reference_policy_logprobs` before `policy.train()` when the value is `None` or has the wrong shape
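A minimal sketch of the worker-patch approach, assuming a stand-in class (`PolicyWorker` below is hypothetical; the real patch targets `MegatronPolicyWorker` inside the container image):

```python
from contextlib import contextmanager

class PolicyWorker:  # stand-in for the real MegatronPolicyWorker
    @contextmanager
    def use_reference_model(self):
        # Original behavior: crashes when the reference model was never set up
        state = self.reference_state_dict  # AttributeError if missing
        yield state

_orig = PolicyWorker.use_reference_model

@contextmanager
def _patched_use_reference_model(self):
    # Early-return guard injected at container startup
    if getattr(self, "reference_state_dict", None) is None:
        yield  # no-op: keep the training weights active
        return
    with _orig(self) as state:  # fall through to the original implementation
        yield state

PolicyWorker.use_reference_model = _patched_use_reference_model

w = PolicyWorker()  # reference model never initialized
with w.use_reference_model():
    print("patched: no-op without a reference model")
```

The same getattr-based guard is what the suggested fix above would add directly in `megatron_policy_worker.py`, making the startup patch unnecessary.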