fix: device mismatch when DPO validation at start with CPU offload (Nemotron) #1930
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/algorithms/dpo.py (1)
Lines 462-463: ⚠️ Potential issue | 🟡 Minor

Misleading comment obscures the purpose of the pre-existing `prepare_for_training()` call.

The comment `# Calculate validation metrics` at line 462 has no relation to the `policy.prepare_for_training()` call that follows; it appears to be a stale comment from a previous reorganisation. With the new call added at line 388, two `prepare_for_training()` invocations now exist within `validate_one_dataset`, and the distinction between them (pre-validation GPU load vs. post-validation GPU restore for the next training step) is invisible. The comment should document the actual intent.

🛠️ Proposed fix

```diff
-    # Calculate validation metrics
-    policy.prepare_for_training()
+    # Restore GPU state after validation; with cpu_offload enabled the
+    # model may have been offloaded to CPU during forward passes, and the
+    # next training step requires everything back on GPU.
+    policy.prepare_for_training()
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `nemo_rl/algorithms/dpo.py` around lines 462-463: the comment "# Calculate validation metrics" is misleading because the following call to policy.prepare_for_training() restores the policy state for the upcoming training step (it does not calculate metrics) and duplicates an earlier prepare_for_training() call in validate_one_dataset; update the comment to clearly state the intent (e.g., "Restore policy state / prepare for next training step on GPU"), leave the earlier prepare_for_training() that loads the model onto GPU before validation, and keep this second call as the restore step, so that the two prepare_for_training() usages in validate_one_dataset are clearly documented and not confused with metric calculation.
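The call structure this prompt describes can be sketched with a hypothetical stub; `StubPolicy` and the metric computation below are invented for illustration, and only the placement of the two `prepare_for_training()` calls mirrors the real `validate_one_dataset`:

```python
class StubPolicy:
    """Invented stand-in; the real policy also manages CPU offload and GPU state."""

    def __init__(self):
        self.calls = []

    def prepare_for_training(self):
        self.calls.append("prepare_for_training")


def validate_one_dataset(policy, batches):
    # First call: load the (possibly CPU-offloaded) model onto GPU
    # before computing reference logprobs and validation metrics.
    policy.prepare_for_training()
    metrics = [sum(b) / len(b) for b in batches]  # stand-in for real metrics
    # Second call: restore GPU state so the next training step can run.
    policy.prepare_for_training()
    return metrics


policy = StubPolicy()
metrics = validate_one_dataset(policy, [[1.0, 3.0]])
assert metrics == [2.0]
assert policy.calls == ["prepare_for_training", "prepare_for_training"]
```

Distinct comments at the two call sites, as the prompt suggests, are what keep this duplication readable.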
Diff context:

```python
        val_metrics = defaultdict(list)
        num_valid_batches = 0
        policy.prepare_for_training()
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Description: Check for prepare_for_lp_inference and compare with prepare_for_training definitions.
echo "=== Searching for prepare_for_lp_inference ==="
rg -n "prepare_for_lp_inference" --type py -A 6 -B 1
echo ""
echo "=== Searching for prepare_for_training definitions ==="
rg -n "def prepare_for_training" --type py -A 6
echo ""
echo "=== All references to prepare_for_training (to understand context) ==="
rg -n "prepare_for_training" --type py | head -30
```

Repository: NVIDIA-NeMo/RL
Length of output: 26906
🏁 Script executed:

```bash
# Get context around line 388 in dpo.py to confirm it's in validation/inference path
sed -n '375,410p' nemo_rl/algorithms/dpo.py
```

Repository: NVIDIA-NeMo/RL
Length of output: 1458
🏁 Script executed:

```bash
# Get more context to see add_ref_logprobs_to_data function
sed -n '360,395p' nemo_rl/algorithms/dpo.py
```

Repository: NVIDIA-NeMo/RL
Length of output: 1177
🏁 Script executed:

```bash
# Find add_ref_logprobs_to_data function definition
rg -n "def add_ref_logprobs_to_data" --type py -A 15
```

Repository: NVIDIA-NeMo/RL
Length of output: 1280
🏁 Script executed:

```bash
# Find get_reference_policy_logprobs calls in add_ref_logprobs_to_data
sed -n '275,330p' nemo_rl/algorithms/dpo.py
```

Repository: NVIDIA-NeMo/RL
Length of output: 2207
🏁 Script executed:

```bash
# Verify the megatron_policy_worker implementations to confirm semantic differences
sed -n '1059,1070p' nemo_rl/models/policy/workers/megatron_policy_worker.py
sed -n '1081,1090p' nemo_rl/models/policy/workers/megatron_policy_worker.py
```

Repository: NVIDIA-NeMo/RL
Length of output: 871
Replace `prepare_for_training()` with `prepare_for_lp_inference()` at line 388.
The method call is incorrect for the validation context. This code is inside `validate_one_dataset()`, which calls `add_ref_logprobs_to_data()` → `get_reference_policy_logprobs()`, an inference operation. Using `prepare_for_training()` puts the model in training mode (`.train()`), but reference logprob computation requires inference mode (`.eval()`). In the wrong mode, dropout and batch normalization behave incorrectly, producing wrong validation metrics.
The PR description correctly identifies that `prepare_for_lp_inference()` should be called. Use `policy.prepare_for_lp_inference()` instead.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `nemo_rl/algorithms/dpo.py` at line 388: in validate_one_dataset() the policy
is wrongly switched into training mode via policy.prepare_for_training();
replace that call with policy.prepare_for_lp_inference() so the reference
logprob computation in add_ref_logprobs_to_data() →
get_reference_policy_logprobs() runs in inference/eval mode (disabling
dropout/BN) and produces correct validation metrics.
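To make the train-vs-eval distinction concrete, here is a toy sketch; `ToyPolicy` and its dropout model are invented for illustration, and the real `prepare_for_training()` / `prepare_for_lp_inference()` also manage device placement and offload state, which this omits:

```python
import random


class ToyPolicy:
    """Invented stand-in that models only the train/eval mode switch."""

    def __init__(self, dropout_p=0.5, seed=0):
        self.training = False
        self.dropout_p = dropout_p
        self.rng = random.Random(seed)

    def prepare_for_training(self):
        self.training = True   # enables dropout: stochastic outputs

    def prepare_for_lp_inference(self):
        self.training = False  # disables dropout: deterministic outputs

    def logprob(self, x):
        if self.training and self.rng.random() < self.dropout_p:
            return x * 0.0  # dropout zeroes the activation
        return x


policy = ToyPolicy()

policy.prepare_for_lp_inference()
eval_vals = [policy.logprob(1.0) for _ in range(5)]
assert eval_vals == [1.0] * 5  # deterministic in eval mode

policy.prepare_for_training()
train_vals = [policy.logprob(1.0) for _ in range(100)]
assert 0.0 in train_vals  # dropout corrupts some "logprobs" in train mode
```

Reference logprobs computed in the stochastic regime would differ run to run, which is why the validation path needs the inference preparation.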
I took a brief look, and it seems `rm.py` has this issue as well (the other algorithms should be fine). Could you check and add the fix to `rm.py` too?
This issue only happens on Nemotron, not on dense models. The RM model's architecture is dense, and Nemotron cannot be loaded for RM, so that path is not tested.
Signed-off-by: ruit <ruit@nvidia.com>
What does this PR do?
Summary
When DPO is run with `val_at_start=true` and `policy.dtensor_cfg.cpu_offload=true`, validation can crash on MoE models (e.g. NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) with:

```
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

This PR fixes the bug by calling `policy.prepare_for_training()` before running the initial validation when `val_at_start` is enabled, so that all buffers (including MoE gate buffers) are on CUDA before any reference-policy logprob computation.
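The failure mode can be illustrated with a toy sketch; `ToyTensor` and `toy_matmul` are invented stand-ins, not the PyTorch API, and only the device-tag check mimics the real error:

```python
class ToyTensor:
    # Invented stand-in that only tracks a device tag.
    def __init__(self, device):
        self.device = device


def toy_matmul(a, b):
    # Mimics PyTorch's device check: mixing cuda and cpu operands fails.
    if a.device != b.device:
        raise RuntimeError(
            "Expected all tensors to be on the same device, but found "
            f"at least two devices, {a.device} and {b.device}!"
        )
    return ToyTensor(a.device)


hidden = ToyTensor("cuda:0")
gate_buffer = ToyTensor("cpu")  # MoE gate buffer left behind by cpu_offload

try:
    toy_matmul(hidden, gate_buffer)
    raised = False
except RuntimeError:
    raised = True
assert raised  # the crash this PR addresses

# The fix moves everything back to GPU before validation runs, which is
# what the preparation call does for the real model.
gate_buffer.device = "cuda:0"
assert toy_matmul(hidden, gate_buffer).device == "cuda:0"
```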
Issues
Related to #1922
Usage
```python
# Add a code snippet demonstrating how to use this
```

Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit