perf: Reduce memory footprint for ChunkedDistribuedLogProb#1895

Open
nujoug wants to merge 3 commits into main from mloh/loss_fn_memory_footprint

Conversation

@nujoug

@nujoug nujoug commented Feb 6, 2026

What does this PR do ?

Reduces the peak memory footprint when using chunking in the loss function (when sequence packing is disabled).

Issues

Improve loss function memory usage

Details

Problem

The current approach stores each chunk's gradient in a list and calls torch.cat at the end to return the full gradient tensor. As a result, at least two copies of the gradient tensor exist at peak.
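As a simplified sketch (hypothetical helper name, not the actual NeMo-RL code), the list-and-concatenate pattern looks like this:

```python
import torch

def backward_with_cat(chunks):
    # Each chunk's gradient is appended to a Python list.
    grad_chunks = []
    for chunk in chunks:
        grad_chunks.append(chunk * 2.0)  # stand-in for the real per-chunk grad math
    # torch.cat allocates a second full-size tensor while the list of
    # chunk gradients is still alive, so roughly 2x gradient memory at peak.
    return torch.cat(grad_chunks, dim=0)
```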

Screenshot 2026-02-06 at 10 45 31 AM

Modification

Preallocate a gradient tensor and copy each chunk's gradient into it in place.

Screenshot 2026-02-06 at 10 57 57 AM
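A minimal sketch of the preallocate-and-copy_ approach (simplified, hypothetical names; the real per-chunk gradient math is elided):

```python
import torch

def backward_preallocated(logits, chunk_size):
    # Preallocate the full gradient once, then write each chunk in place.
    grad_input = torch.empty_like(logits)
    n = logits.shape[0]
    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        grad_chunk = logits[start:end] * 2.0  # stand-in for the per-chunk grad math
        grad_input[start:end].copy_(grad_chunk)
        del grad_chunk  # drop the temporary before processing the next chunk
    return grad_input
```

Only one full-size gradient tensor plus one chunk-size temporary is alive at any time.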

Additional Gains

With the current approach, reducing the chunk size yields less peak-memory reduction than expected because some intermediate tensors overlap in lifetime (they are deleted lazily).

Screenshot 2026-02-06 at 11 09 39 AM Screenshot 2026-02-06 at 11 14 37 AM

With explicit deallocation of the intermediate tensors (using del), the memory footprint reduction is more significant (-0.4 GiB to -0.6 GiB).

Screenshot 2026-02-06 at 11 28 39 AM Screenshot 2026-02-06 at 11 27 07 AM

Caveats

The modification does not reduce peak memory when sequence packing is enabled. In that case SequencePackingLossWrapper is used and the default torch.autograd machinery handles the backprop, which causes this undesirable behavior.

Screenshot 2026-02-06 at 11 36 53 AM

This should be solvable once there is a custom torch.autograd.Function that can handle sequence packing.
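One possible shape for such a custom torch.autograd.Function (a hypothetical skeleton, not the actual SequencePackingLossWrapper fix; the loss and gradient math are stand-ins):

```python
import torch

class ChunkedPackedLoss(torch.autograd.Function):
    """Hypothetical skeleton: a custom Function whose backward preallocates
    grad_input and fills it chunk by chunk, instead of letting default
    autograd materialize extra gradient copies."""

    @staticmethod
    def forward(ctx, logits):
        ctx.save_for_backward(logits)
        return logits.sum()  # stand-in for the packed loss computation

    @staticmethod
    def backward(ctx, grad_output):
        (logits,) = ctx.saved_tensors
        grad_input = torch.empty_like(logits)
        chunk_size = 2
        for start in range(0, logits.shape[0], chunk_size):
            end = min(start + chunk_size, logits.shape[0])
            # For sum(), the per-chunk gradient is just grad_output broadcast.
            grad_input[start:end].copy_(grad_output)
        return grad_input
```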

Summary by CodeRabbit

  • Performance Improvements

    • Reduced memory use and improved efficiency during backward/gradient computation by switching to preallocated in-place updates and explicit cleanup, lowering peak memory and improving stability.
  • Tests

    • Added comprehensive distributed/chunked entropy tests (forward and optional backward) with multi-GPU validation to ensure numerical parity with baseline.

@nujoug nujoug self-assigned this Feb 6, 2026
@nujoug nujoug force-pushed the mloh/loss_fn_memory_footprint branch from aaa5b0b to 3a6cd1a Compare February 6, 2026 18:31
@nujoug nujoug changed the title perf: Reduce memory footprint for ChunkedDistribuedLogProb Draft: perf: Reduce memory footprint for ChunkedDistribuedLogProb Feb 6, 2026
@nujoug nujoug linked an issue Feb 6, 2026 that may be closed by this pull request
@nujoug nujoug force-pushed the mloh/loss_fn_memory_footprint branch from 3a6cd1a to d1ce6b9 Compare February 6, 2026 19:47
@nujoug nujoug changed the title Draft: perf: Reduce memory footprint for ChunkedDistribuedLogProb perf: Reduce memory footprint for ChunkedDistribuedLogProb Feb 6, 2026
@nujoug nujoug marked this pull request as ready for review February 6, 2026 19:49
@nujoug nujoug requested a review from a team as a code owner February 6, 2026 19:49
@nujoug nujoug force-pushed the mloh/loss_fn_memory_footprint branch from d1ce6b9 to f10e4e4 Compare February 6, 2026 19:51
@coderabbitai
Contributor

coderabbitai bot commented Feb 6, 2026

📝 Walkthrough

Walkthrough

Replaced list-based per-chunk gradient accumulation in distributed log-prob/backward implementations with preallocated grad_input tensors and in-place per-chunk copy_/mul_ updates; added explicit deletion of temporary tensors. Added comprehensive chunked-distributed entropy tests using Ray actors and GPU sharding.

Changes

Cohort / File(s) Summary
Backward Pass & Chunked Logprob
nemo_rl/distributed/model_utils.py
Refactored multiple backward implementations (including ChunkedDistributedLogprob) to preallocate a grad_input sized like vocab_parallel_logits, write per-chunk via grad_input_chunk.copy_/mul_, remove Python-list accumulation/concat, and explicitly del temporaries after each chunk. Single-chunk path updated to use a preallocated chunk view.
Chunked Entropy Tests & Ray Actors
tests/unit/distributed/test_model_utils.py
Added ChunkedDistributedEntropy tests: new ChunkedDistributedEntropyTestActor (Ray remote), actor lifecycle registry constant, fixture to register the actor, and parameterized test_chunked_distributed_entropy covering tp_size, chunk_size, and inference_only modes; compares chunked partitioned results (forward and optional backward) against full baseline across GPU shards. Also updated test imports to include ChunkedDistributedEntropy.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 30.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly summarizes the main change: reducing memory footprint for ChunkedDistributedLogprob through preallocation instead of list accumulation and torch.cat.
Test Results For Major Changes ✅ Passed PR comprehensively documents test results and performance metrics for memory optimization changes, including torch.cat removal in backward methods, preallocation pattern implementation, and extensive test validation with numerical tolerance checks.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@nemo_rl/distributed/model_utils.py`:
- Line 223: The preallocated grad_input created with
torch.empty_like(vocab_parallel_logits) inherits bf16/fp16 and thus silently
truncates float32 gradient arithmetic; change the allocation of grad_input in
model_utils.py so it is explicitly float32 (e.g., create an empty tensor with
the same shape/device but dtype=torch.float32) so the subsequent operations
(is_chosen.float().sub_(softmax_output), the copy_ into grad_input, and the mul_
call) run in float32 and preserve gradient precision; update the grad_input
creation site (the grad_input variable near the vocab_parallel_logits usage) to
allocate float32 and ensure device/shape match.
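The dtype concern above can be illustrated with a small sketch (illustrative values only, not the actual model_utils.py code):

```python
import torch

# bf16 logits: torch.empty_like would inherit bfloat16 here and silently
# truncate float32 gradient arithmetic.
logits = torch.ones(4, 8, dtype=torch.bfloat16)

# Explicit float32 allocation with matching shape and device, as suggested.
grad_input = torch.empty(logits.shape, dtype=torch.float32, device=logits.device)
grad_input.copy_(logits)  # copy_ upcasts the bf16 values into the fp32 buffer
```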
🧹 Nitpick comments (1)
nemo_rl/distributed/model_utils.py (1)

209-260: Consider applying the same preallocation pattern to ChunkedDistributedGatherLogprob.backward and ChunkedDistributedEntropy.backward.

Both sibling backward methods (lines 334–386 and 1041–1060) still use the list-accumulate + torch.cat pattern. They'd benefit from the same preallocation optimization for consistency and memory reduction.

@nujoug nujoug force-pushed the mloh/loss_fn_memory_footprint branch from 7725656 to ac29781 Compare February 6, 2026 21:21
@nujoug
Author

nujoug commented Feb 6, 2026

Nitpick comments (1)

nemo_rl/distributed/model_utils.py (1), lines 209-260: Consider applying the same preallocation pattern to ChunkedDistributedGatherLogprob.backward and ChunkedDistributedEntropy.backward.

Both sibling backward methods (lines 334–386 and 1041–1060) still use the list-accumulate + torch.cat pattern. They'd benefit from the same preallocation optimization for consistency and memory reduction.

Should we apply this to ChunkedDistributedGatherLogprob.backward and ChunkedDistributedEntropy.backward as well?

Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: mloh <mloh@nvidia.com>
Signed-off-by: mloh <mloh@nvidia.com>
@nujoug nujoug force-pushed the mloh/loss_fn_memory_footprint branch from ac29781 to df5e4dc Compare February 17, 2026 19:14
@nujoug nujoug requested a review from a team as a code owner February 17, 2026 19:14
@nujoug nujoug added the CI:L2 Run doctests, unit tests, functional tests, and convergence tests label Feb 17, 2026
@nujoug
Copy link
Author

nujoug commented Feb 17, 2026

@CodeRabbit review

@coderabbitai
Contributor

coderabbitai bot commented Feb 17, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/unit/distributed/test_model_utils.py (1)

1048-1048: Long line — wrap for consistency with other FQN definitions in this file.

Other FQN constants (e.g., lines 513–515, 869–871) use parenthesized multi-line formatting.

Suggested wrap
-CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN = f"{ChunkedDistributedEntropyTestActor.__module__}.ChunkedDistributedEntropyTestActor"
+CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN = (
+    f"{ChunkedDistributedEntropyTestActor.__module__}.ChunkedDistributedEntropyTestActor"
+)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/distributed/test_model_utils.py` at line 1048, The constant
CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN is a too-long single-line f-string;
update its definition to use the same parenthesized multi-line wrapping style
used by other FQN constants (referencing ChunkedDistributedEntropyTestActor and
its __module__) so the string is split across lines for readability and
consistency with surrounding FQN definitions.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@tests/unit/distributed/test_model_utils.py`:
- Line 1048: The constant CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN is a
too-long single-line f-string; update its definition to use the same
parenthesized multi-line wrapping style used by other FQN constants (referencing
ChunkedDistributedEntropyTestActor and its __module__) so the string is split
across lines for readability and consistency with surrounding FQN definitions.

@nujoug
Author

nujoug commented Feb 23, 2026

@terrykong

Hi Terry, can you look at the codecov/patch failure? tests/unit/distributed/test_model_utils.py should have already covered the highlighted lines.

Thanks


Labels

CI:L2 Run doctests, unit tests, functional tests, and convergence tests


Development

Successfully merging this pull request may close these issues.

Improve loss function memory usage

1 participant