refactor: refactor loss function #1920

Merged
yuki-97 merged 15 commits into main from yukih/refactor-loss
Mar 2, 2026

Conversation

@yuki-97 yuki-97 commented Feb 10, 2026

  1. Move parallelism-related logic out of the loss functions.
  2. Add LossInputType (logit, logprob, distillation) and prepare_loss_input, which converts logits into the input each loss expects and handles the parallelism-related logic.
  3. Update the loss file structure:
├── loss
│   ├── __init__.py
│   ├── interfaces.py
│   ├── loss_functions.py
│   ├── utils.py
│   └── wrapper.py
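
As a rough illustration of item 2, the sketch below shows how an input-type enum could drive a single conversion step so that a loss function never has to touch raw logits unless it asks for them. Only the names LossInputType and prepare_loss_input and the three members (logit, logprob, distillation) come from this description. The signature, the return shapes, the top_k parameter, and the log_softmax helper are assumptions made for illustration, written with plain Python lists rather than tensors, and are not the NeMo RL implementation.

```python
import math
from enum import Enum


class LossInputType(Enum):
    # Member names follow the PR description; the string values are assumed.
    LOGIT = "logit"
    LOGPROB = "logprob"
    DISTILLATION = "distillation"


def log_softmax(row):
    """Numerically stable log-softmax over one vocabulary row."""
    peak = max(row)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
    return [x - log_z for x in row]


def prepare_loss_input(logits, token_ids, input_type, top_k=2):
    """Hypothetical signature: turn per-position logits into what a loss expects.

    logits: list of vocabulary rows, one per position.
    token_ids: the target token id at each position.
    """
    if input_type is LossInputType.LOGIT:
        # The loss wants raw logits, so pass them through untouched.
        return logits
    logprobs = [log_softmax(row) for row in logits]
    if input_type is LossInputType.LOGPROB:
        # Keep only the log-probability of each target token.
        return [lp[t] for lp, t in zip(logprobs, token_ids)]
    if input_type is LossInputType.DISTILLATION:
        # Assumed form: top-k (token_id, logprob) pairs per position.
        return [
            sorted(enumerate(lp), key=lambda pair: pair[1], reverse=True)[:top_k]
            for lp in logprobs
        ]
    raise ValueError(f"unsupported input type: {input_type}")


logits = [[2.0, 0.0, -1.0], [0.5, 0.5, 3.0]]
token_logprobs = prepare_loss_input(logits, [0, 2], LossInputType.LOGPROB)
```

In a layout like this, any sharded or parallel computation would live inside the conversion step, which is the separation item 1 describes.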

Test Result
https://wandb.ai/nvidia/refactor-loss-yukih?nw=0k8r2x613fml

GRPO (plot not preserved)
SFT, Distillation (plots not preserved)
DPO, RM (plots not preserved)

All nightly tests passed except those that already fail on main; see #2041.

Summary by CodeRabbit

  • Documentation

    • Updated module import paths in guides and documentation to reflect reorganized loss function architecture.
  • Refactor

    • Restructured loss function modules into a new hierarchical package with improved interfaces and abstractions.
    • Introduced LossInputType enumeration for standardized loss function input specifications.
    • Updated loss function signatures to use pre-computed log probabilities instead of raw logits.
    • Reorganized loss utilities including input preparation and sequence packing wrappers.
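
To illustrate the signature change in the summary above, here is a minimal sketch of a loss that receives pre-computed per-token log probabilities instead of raw logits. The function name masked_nll_loss, its parameters, and the plain-list data layout are hypothetical and chosen for illustration; they are not taken from the PR.

```python
def masked_nll_loss(token_logprobs, token_mask):
    """Mean negative log-likelihood over unmasked tokens.

    token_logprobs: log-probability of each target token, already computed
        upstream, so no softmax over the vocabulary is needed here.
    token_mask: 1 for tokens that count toward the loss, 0 otherwise.
    """
    if len(token_logprobs) != len(token_mask):
        raise ValueError("token_logprobs and token_mask must have equal length")
    num_valid = sum(token_mask)
    if num_valid == 0:
        # Nothing to average over; return zero rather than dividing by zero.
        return 0.0
    total = sum(-lp * m for lp, m in zip(token_logprobs, token_mask))
    return total / num_valid


# Three tokens; the last one is padding and is masked out.
loss = masked_nll_loss([-0.5, -1.5, -9.0], [1, 1, 0])
```

Because the function only sees log probabilities and a mask, it stays the same regardless of how the logits were produced or sharded.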

@github-actions

⚠️ File Consistency Check

Check based on commit: 696f9ad (PR #1920 from yukih/refactor-loss)

⚠️ DTensor Policy Worker Synchronization Warning

The file nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py was modified in this PR, but nemo_rl/models/policy/workers/dtensor_policy_worker.py was not updated.

Why this matters:
These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.

Action required:

  • Please review if the changes in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py should also be applied to nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • Update nemo_rl/models/policy/workers/dtensor_policy_worker.py if necessary to maintain consistency
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • Not modified: nemo_rl/models/policy/workers/dtensor_policy_worker.py

This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
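
The rule this check applies can be sketched as follows. The two file paths are the ones named in the comment above. The function name, the status strings, the symmetric handling of either file being the one left out, and the idea of passing in a list of changed paths are assumptions for illustration, not the actual workflow code.

```python
V1 = "nemo_rl/models/policy/workers/dtensor_policy_worker.py"
V2 = "nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py"


def check_worker_sync(changed_files):
    """Return (status, message) for the paths changed in a PR."""
    changed = set(changed_files)
    touched = [path for path in (V1, V2) if path in changed]
    if len(touched) == 2:
        return "ok", "Both DTensor policy worker files were modified."
    if len(touched) == 1:
        # Assumed to be symmetric: warn whichever file was left out.
        other = V2 if touched[0] == V1 else V1
        return "warning", f"{touched[0]} was modified, but {other} was not updated."
    return "skipped", "Neither DTensor policy worker file was modified."


# README.md is a placeholder for any unrelated changed file.
status, message = check_worker_sync([V2, "README.md"])
```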

@yuki-97 force-pushed the yukih/refactor-loss branch from 696f9ad to 54e1283 on February 10, 2026 12:39

@terrykong terrykong left a comment

@hemildesai can you review the automodel changes
@yaoyu-33 @cuichenx to comment on the change from the megatron side (although this PR hasn't implemented that part yet)
@zpqiu to comment on the distillation changes

@yuki-97 force-pushed the yukih/refactor-loss branch from 54e1283 to f203c94 on February 26, 2026 05:35
@github-actions github-actions bot added the documentation Improvements or additions to documentation label Feb 26, 2026
@github-actions

ℹ️ File Consistency Check

Check based on commit: 90693e1 (PR #1920 from yukih/refactor-loss)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.



@github-actions

ℹ️ File Consistency Check

Check based on commit: bdc4277 (PR #1920 from yukih/refactor-loss)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.



@yuki-97 yuki-97 added the CI:L1 Run doctests, unit tests, and functional tests label Feb 26, 2026
@github-actions

ℹ️ File Consistency Check

Check based on commit: a81e0cc (PR #1920 from yukih/refactor-loss)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.



@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 27, 2026
@github-actions

ℹ️ File Consistency Check

Check based on commit: d641ee9 (PR #1920 from yukih/refactor-loss)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.



@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 27, 2026
@github-actions

ℹ️ File Consistency Check

Check based on commit: 1b752f6 (PR #1920 from yukih/refactor-loss)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.



@yuki-97 yuki-97 marked this pull request as ready for review February 27, 2026 15:22
@yuki-97 yuki-97 requested review from a team as code owners February 27, 2026 15:22

@terrykong terrykong left a comment

generally lgtm. i think this is a great change. makes our loss abstractions easier to write and grok. thanks @yuki-97 !

Since this change touches every algorithm, can you run the nightlies?

@github-actions

github-actions bot commented Mar 2, 2026

ℹ️ File Consistency Check

Check based on commit: e363ebc (PR #1920 from yukih/refactor-loss)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.



yuki-97 added 15 commits March 2, 2026 10:49
Signed-off-by: Yuki Huang <yukih@nvidia.com> (on each of the 15 commits)
@yuki-97 force-pushed the yukih/refactor-loss branch from e363ebc to 443d7ad on March 2, 2026 02:49
@github-actions

github-actions bot commented Mar 2, 2026

ℹ️ File Consistency Check

Check based on commit: 443d7ad (PR #1920 from yukih/refactor-loss)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.



@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Mar 2, 2026
yuki-97 commented Mar 2, 2026

> Since this change touches every algorithm, can you run the nightlies?

Yes, I've run the nightly tests. All of them passed except those that already fail on main, so this PR should be fine.

I filed an issue with the error logs: #2041. Could you assign someone to fix it, @terrykong?

@yuki-97 yuki-97 enabled auto-merge (squash) March 2, 2026 07:43
@yuki-97 yuki-97 merged commit dc9dce4 into main Mar 2, 2026
59 of 64 checks passed
@yuki-97 yuki-97 deleted the yukih/refactor-loss branch March 2, 2026 10:51

Labels

CI:L1: Run doctests, unit tests, and functional tests
documentation: Improvements or additions to documentation

2 participants