
fix: device mismatch when DPO validation at start with CPU offload (Nemotron) #1930

Open
RayenTian wants to merge 2 commits into main from ruit/fix_dpo_cpu_offload

Conversation

Contributor

@RayenTian RayenTian commented Feb 12, 2026

What does this PR do ?

Summary
When DPO is run with val_at_start=true and policy.dtensor_cfg.cpu_offload=true, validation can crash on MoE models (e.g. NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) with:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Issues

Related to #1922

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this
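A hypothetical configuration that reproduces the crash. The keys val_at_start and policy.dtensor_cfg.cpu_offload and the model name come from the PR description; the surrounding structure and other values are illustrative, not necessarily the project's actual config schema.

```yaml
# Illustrative reproduction config; only the flagged keys come from the PR description.
dpo:
  val_at_start: true            # run validation before the first training step
policy:
  model_name: NVIDIA-Nemotron-3-Nano-30B-A3B-BF16   # MoE model that triggered the bug
  dtensor_cfg:
    cpu_offload: true           # weights are offloaded to CPU between steps
```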

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • Bug Fixes
    • Improved policy state management during validation to enhance consistency and behavioral alignment throughout the validation process.

@RayenTian RayenTian added the CI:L1 Run doctests, unit tests, and functional tests label Feb 12, 2026
@RayenTian RayenTian changed the title fix: fix dpo device mismatch bug when enable cpu_offload and val_at_start fix: Fix device mismatch when DPO runs validation at start with CPU offload (Nemotron MoE) Feb 12, 2026
@RayenTian RayenTian force-pushed the ruit/fix_dpo_cpu_offload branch 2 times, most recently from e77f145 to 662aebc Compare February 12, 2026 08:04
@RayenTian RayenTian force-pushed the ruit/fix_dpo_cpu_offload branch from 662aebc to f1cae7f Compare February 23, 2026 06:09
@RayenTian RayenTian added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 23, 2026
@RayenTian RayenTian changed the title fix: Fix device mismatch when DPO runs validation at start with CPU offload (Nemotron MoE) fix: device mismatch when DPO runs validation at start with CPU offload Feb 23, 2026
@RayenTian RayenTian changed the title fix: device mismatch when DPO runs validation at start with CPU offload fix: device mismatch when DPO runs validation at start with CPU offload(MoE) Feb 23, 2026
@RayenTian RayenTian marked this pull request as ready for review February 24, 2026 02:52
@RayenTian RayenTian requested a review from a team as a code owner February 24, 2026 02:52
@RayenTian RayenTian changed the title fix: device mismatch when DPO runs validation at start with CPU offload(MoE) fix: device mismatch when DPO validation at start with CPU offload(Nemotron) Feb 24, 2026
@RayenTian RayenTian added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 24, 2026
@RayenTian RayenTian requested a review from yuki-97 February 24, 2026 02:53
@RayenTian RayenTian requested a review from terrykong February 24, 2026 02:53

coderabbitai bot commented Feb 24, 2026

📝 Walkthrough

Walkthrough

The change adds a single line to invoke policy.prepare_for_training() in the validate_one_dataset method of the DPO algorithm, positioning the call after metrics initialization and before the validation batch processing loop.

Changes

Cohort / File(s) Summary
DPO Validation Preparation
nemo_rl/algorithms/dpo.py
Added call to policy.prepare_for_training() at the start of validate_one_dataset to transition the policy to training preparation mode before processing validation batches.
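The shape of the change can be sketched with a toy stand-in. The Policy class below is hypothetical and exists only to illustrate the control flow; of its names, only prepare_for_training comes from the diff, and none of this is the real nemo_rl interface.

```python
# Toy model of the control flow this PR changes: with CPU offload enabled,
# the policy's weights sit on CPU between steps, so validation must first
# move them back to GPU before running forward passes on GPU batches.
class Policy:
    def __init__(self):
        self.device = "cpu"  # cpu_offload leaves weights on CPU between steps

    def prepare_for_training(self):
        self.device = "cuda"  # reload parameters and buffers onto the GPU

    def forward(self, batch_device):
        if self.device != batch_device:
            raise RuntimeError(
                "Expected all tensors to be on the same device, but found "
                f"at least two devices, {batch_device} and {self.device}!"
            )
        return 0.0  # placeholder metric


def validate_one_dataset(policy, batches):
    metrics = []
    policy.prepare_for_training()  # the added line: bring the policy back to GPU
    for _batch in batches:
        metrics.append(policy.forward("cuda"))  # validation batches live on GPU
    return metrics
```

Without the added call, the first forward pass sees CPU weights against GPU batches and raises the device-mismatch error quoted in the PR description.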

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
  • Docstring Coverage ✅ Passed: docstring coverage is 100.00%, which meets the required 80.00% threshold.
  • Test Results For Major Changes ✅ Passed: this PR is a minor change (a single-line addition) that is a targeted bug fix for a device mismatch in DPO validation with CPU offload on MoE models.
  • Description Check ✅ Passed: check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title clearly identifies the specific bug fix (device mismatch in DPO validation at start with CPU offload) and the affected model type (Nemotron/MoE), which corresponds to the changeset adding policy.prepare_for_training() to address buffer placement.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/algorithms/dpo.py (1)

462-463: ⚠️ Potential issue | 🟡 Minor

Misleading comment obscures the purpose of the pre-existing prepare_for_training() call.

The comment # Calculate validation metrics at line 462 has no relation to the policy.prepare_for_training() call that follows — it appears to be a stale comment from a previous reorganisation. With the new call added at line 388, two prepare_for_training() invocations now exist within validate_one_dataset, and the distinction between them (pre-validation GPU load vs. post-validation GPU restore for the next training step) is completely invisible. The comment should document the actual intent.

🛠️ Proposed fix
-        # Calculate validation metrics
-        policy.prepare_for_training()
+        # Restore GPU state after validation; with cpu_offload enabled the
+        # model may have been offloaded to CPU during forward passes, and the
+        # next training step requires everything back on GPU.
+        policy.prepare_for_training()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/algorithms/dpo.py` around lines 462 - 463, The comment "# Calculate
validation metrics" is misleading because the following call to
policy.prepare_for_training() is restoring the policy state for the upcoming
training step (not calculating metrics) and duplicates an earlier
prepare_for_training() call in validate_one_dataset; update the comment to
clearly state the intent (e.g., "Restore policy state / prepare for next
training step on GPU") and ensure it appears only where appropriate (leave the
earlier prepare_for_training() that lowers GPU usage before validation, and keep
this second call as the restore step), so that the two prepare_for_training()
usages in validate_one_dataset are clearly documented and not confused with
metric calculation.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@nemo_rl/algorithms/dpo.py`:
- Line 388: In validate_one_dataset() the policy is wrongly switched into
training mode via policy.prepare_for_training(); replace that call with
policy.prepare_for_lp_inference() so the reference logprob computation in
add_ref_logprobs_to_data() → get_reference_policy_logprobs() runs in
inference/eval mode (disabling dropout/BN) and produces correct validation
metrics.

---

Outside diff comments:
In `@nemo_rl/algorithms/dpo.py`:
- Around line 462-463: The comment "# Calculate validation metrics" is
misleading because the following call to policy.prepare_for_training() is
restoring the policy state for the upcoming training step (not calculating
metrics) and duplicates an earlier prepare_for_training() call in
validate_one_dataset; update the comment to clearly state the intent (e.g.,
"Restore policy state / prepare for next training step on GPU") and ensure it
appears only where appropriate (leave the earlier prepare_for_training() that
lowers GPU usage before validation, and keep this second call as the restore
step), so that the two prepare_for_training() usages in validate_one_dataset are
clearly documented and not confused with metric calculation.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9148186 and 801c115.

📒 Files selected for processing (1)
  • nemo_rl/algorithms/dpo.py


val_metrics = defaultdict(list)
num_valid_batches = 0
policy.prepare_for_training()

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Description: Check for prepare_for_lp_inference and compare with prepare_for_training definitions.

echo "=== Searching for prepare_for_lp_inference ==="
rg -n "prepare_for_lp_inference" --type py -A 6 -B 1

echo ""
echo "=== Searching for prepare_for_training definitions ==="
rg -n "def prepare_for_training" --type py -A 6

echo ""
echo "=== All references to prepare_for_training (to understand context) ==="
rg -n "prepare_for_training" --type py | head -30

Repository: NVIDIA-NeMo/RL

Length of output: 26906


🏁 Script executed:

# Get context around line 388 in dpo.py to confirm it's in validation/inference path
sed -n '375,410p' nemo_rl/algorithms/dpo.py



🏁 Script executed:

# Get more context to see add_ref_logprobs_to_data function
sed -n '360,395p' nemo_rl/algorithms/dpo.py



🏁 Script executed:

# Find add_ref_logprobs_to_data function definition
rg -n "def add_ref_logprobs_to_data" --type py -A 15



🏁 Script executed:

# Find get_reference_policy_logprobs calls in add_ref_logprobs_to_data
sed -n '275,330p' nemo_rl/algorithms/dpo.py



🏁 Script executed:

# Verify the megatron_policy_worker implementations to confirm semantic differences
sed -n '1059,1070p' nemo_rl/models/policy/workers/megatron_policy_worker.py
sed -n '1081,1090p' nemo_rl/models/policy/workers/megatron_policy_worker.py



Replace prepare_for_training() with prepare_for_lp_inference() at line 388.

The method call is incorrect for the validation context. This code is inside validate_one_dataset(), which calls add_ref_logprobs_to_data() → get_reference_policy_logprobs(), an inference operation. Using prepare_for_training() sets the model to training mode (.train()), but reference logprob computation requires inference mode (.eval()). The incorrect mode can cause dropout and batch normalization to behave incorrectly, leading to wrong validation metrics.

The PR description correctly identifies that prepare_for_lp_inference() should be called. Use policy.prepare_for_lp_inference() instead.
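The train/eval distinction the review raises can be illustrated with a toy dropout layer in plain Python. This is a stand-in for illustration only; the real policy workers presumably toggle PyTorch's module.train()/module.eval(), not this class.

```python
import random


class TinyDropout:
    """Minimal dropout stand-in showing why eval mode matters for logprobs."""

    def __init__(self, p=0.5):
        self.p = p
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    def __call__(self, xs):
        if not self.training:
            return list(xs)  # eval mode: deterministic identity
        # training mode: randomly zero elements and rescale the survivors
        return [0.0 if random.random() < self.p else x / (1 - self.p) for x in xs]
```

In eval mode the layer is a deterministic identity, so repeated logprob passes agree; in training mode every pass perturbs the activations differently, which would distort validation metrics exactly as the comment warns.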

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_rl/algorithms/dpo.py` at line 388, In validate_one_dataset() the policy
is wrongly switched into training mode via policy.prepare_for_training();
replace that call with policy.prepare_for_lp_inference() so the reference
logprob computation in add_ref_logprobs_to_data() →
get_reference_policy_logprobs() runs in inference/eval mode (disabling
dropout/BN) and produces correct validation metrics.


val_metrics = defaultdict(list)
num_valid_batches = 0
policy.prepare_for_training()

I took a brief look, and it seems rm.py has this issue as well (other algorithms should be fine). Can you check and add the fix to rm.py too?

Contributor Author

This issue only happens on Nemotron, not on dense models. The RM model architecture is dense, and Nemotron cannot be loaded for RM, so this path is not tested.

@RayenTian RayenTian force-pushed the ruit/fix_dpo_cpu_offload branch from 2e91a78 to 1215617 Compare February 27, 2026 06:26
Signed-off-by: ruit <ruit@nvidia.com>
@RayenTian RayenTian requested a review from a team as a code owner February 27, 2026 09:48
@RayenTian RayenTian added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 27, 2026

Labels

CI:L1 Run doctests, unit tests, and functional tests
