Conversation

@gitlost-murali (Contributor) commented Nov 30, 2025

Summary

When tensor parallelism is enabled, the reference model's logits are sharded across GPUs on the vocabulary dimension. Previously, we called full_tensor() to gather the complete vocab on each GPU before computing log probabilities.

This PR adds compute_logprobs_parallel(), which computes log probabilities in a distributed fashion using the log-sum-exp trick across the vocab shards.

Memory savings (measured)

| Scenario | Memory per GPU |
| --- | --- |
| Old (full_tensor + compute_logprobs) | 58 GB |
| New (parallel logprobs) | 34 GB |
| Saved | 24 GB (~41%) |

Old state usage:

[WandB memory chart: "current"]

New state (parallel logprobs) usage:

[WandB memory chart: "optimized"]

Tested with batch=4, seq_len=9k (1024 prompt tokens + 8192 response tokens), vocab=150k, TP=2

Changes

  • New: src/forge/util/parallel_logprobs.py - distributed log-prob computation for vocab-sharded DTensors
  • New: tests/unit_tests/util/test_parallel_logprobs.py - correctness tests against sequential implementation
  • Modified: src/forge/actors/reference_model.py - uses parallel version when TP is enabled

Implementation

Uses a distributed log-softmax without gathering the full vocabulary (a sketch follows the steps below):

  1. All-reduce MAX for numerical stability
  2. All-reduce SUM of local exp(x - max)
  3. Each rank gathers the target-token logits only for tokens whose ids fall in its vocab shard
  4. All-reduce SUM to combine (only the owning rank contributes a non-zero value)
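A minimal sketch of these four steps, assuming the local logits are a plain [batch, seq, vocab_shard] tensor (e.g. the DTensor's local shard via to_local()) and `group` is the tensor-parallel process group. Names like `sharded_logprobs` and `vocab_start` are illustrative, not the PR's actual API:

```python
import torch
import torch.distributed as dist


def sharded_logprobs(
    local_logits: torch.Tensor,   # [B, T, V_shard] — this rank's vocab shard
    targets: torch.Tensor,        # [B, T] — global token ids
    vocab_start: int,             # first vocab index owned by this rank
    group: dist.ProcessGroup | None = None,
    temperature: float = 1.0,
) -> torch.Tensor:
    local_logits = local_logits.float() / temperature
    vocab_shard = local_logits.size(-1)

    # 1. All-reduce MAX for numerical stability.
    global_max = local_logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=group)

    # 2. All-reduce SUM of local exp(x - max) to get the global normalizer.
    sum_exp = (local_logits - global_max).exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)
    logsumexp = global_max + sum_exp.log()            # [B, T, 1]

    # 3. Gather the target-token logit, but only for tokens this shard owns.
    local_ids = targets - vocab_start
    owned = (local_ids >= 0) & (local_ids < vocab_shard)
    gathered = torch.gather(
        local_logits, -1, local_ids.clamp(0, vocab_shard - 1).unsqueeze(-1)
    ).squeeze(-1)
    target_logits = torch.where(owned, gathered, torch.zeros_like(gathered))

    # 4. All-reduce SUM so every rank ends up with the owning rank's value.
    dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=group)

    return target_logits - logsumexp.squeeze(-1)
```

The actual compute_logprobs_parallel() in src/forge/util/parallel_logprobs.py may differ in signature and in how it handles alignment and temperature.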

Testing

  • Verified results match compute_logprobs() within a 1e-5 tolerance (see the sketch after this list)
  • Tested temperature scaling, alignment modes, numerical stability with extreme values
  • Tested 2-way vocab sharded config
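For reference, a single-process sketch of the core correctness check: emulate two vocab shards and verify that the shard-wise log-sum-exp combine matches a reference torch.log_softmax. This mirrors the idea of tests/unit_tests/util/test_parallel_logprobs.py but is not the actual test file:

```python
import torch


def reference_logprobs(logits, targets):
    return torch.log_softmax(logits.float(), dim=-1).gather(
        -1, targets.unsqueeze(-1)
    ).squeeze(-1)


def emulated_sharded_logprobs(logits, targets, num_shards=2):
    shards = logits.float().chunk(num_shards, dim=-1)
    # Emulate the all-reduce MAX / SUM over the vocab shards.
    global_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values
    sum_exp = sum((s - global_max.unsqueeze(-1)).exp().sum(dim=-1) for s in shards)
    logsumexp = global_max + sum_exp.log()
    target_logits = logits.float().gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return target_logits - logsumexp


torch.manual_seed(0)
logits = torch.randn(2, 5, 64) * 10      # large values to probe numerical stability
targets = torch.randint(0, 64, (2, 5))
assert torch.allclose(
    reference_logprobs(logits, targets),
    emulated_sharded_logprobs(logits, targets),
    atol=1e-5,
)
```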

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Nov 30, 2025
@gitlost-murali gitlost-murali changed the title feat: Distributed log-prob computation for vocab-sharded reference model feat: Optimize reference model GPU usage by distributed log-prob computation on vocab-sharded logits Nov 30, 2025
@gitlost-murali gitlost-murali changed the title feat: Optimize reference model GPU usage by distributed log-prob computation on vocab-sharded logits feat: Reduce reference model memory usage with distributed log-probs comp Nov 30, 2025
@gitlost-murali gitlost-murali changed the title feat: Reduce reference model memory usage with distributed log-probs comp feat: Reduce reference model memory with parallel logprob computation Nov 30, 2025
@joecummings (Member) left a comment
I like this idea!

Could I ask for a few things?

  1. WandB logs that show the memory saved. This is always helpful as part of verifying correctness.
  2. Combine the parallel_logprobs and regular logprobs in the same file. No need to split that out just yet.
  3. Look for ways that this code could be simplified and/or factored out. Claude can be very verbose :)

Looking forward to getting this landed!

@gitlost-murali (Contributor, Author)

Thanks for the review @joecummings !

I refactored the code as per feedback. Less Claude footprint now :). Let me know if the code needs to be further simplified/refactored.

I attached the WandB chart images in the description. Also attaching them here:

Old state usage:

[WandB memory chart: "current"]

New state (parallel logprobs) usage:

[WandB memory chart: "optimized"]

@gitlost-murali gitlost-murali force-pushed the optimize-ref-model-usage branch from 53ddb5b to 20f59bf on December 3, 2025 20:36
@gitlost-murali (Contributor, Author)

Hi @felipemello1,

The unit tests were failing because pytz was missing from the CI env. I rebased on main now. It looks like #618 (easy - remove pytz) takes care of this.

Ran the tests locally. All pass. Can you trigger the tests again please?

Thanks!

@felipemello1 (Contributor) commented Dec 3, 2025

@gitlost-murali, thanks for opening the PR. Great results!

Can you try running the non-sharded version but compiling F.cross_entropy? e.g.

```python
@torch.compile()
def compute_logprobs(...):
    ...
```

I think that simply compiling it greatly reduces the memory, since it never materializes the intermediate activations. This might be something to do in addition to your work rather than in place of it. I am skeptical about using the log-sum-exp directly instead of F.cross_entropy, since the compiled version is highly optimized.
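For illustration, a hedged sketch of that suggestion: a cross-entropy based logprob computation wrapped in torch.compile so the intermediates can be fused. The function name and signature below are illustrative, not Forge's actual API:

```python
import torch
import torch.nn.functional as F


@torch.compile()
def compute_logprobs_compiled(
    logits: torch.Tensor,    # [B, T, V] full (non-sharded) logits
    targets: torch.Tensor,   # [B, T] token ids
    temperature: float = 1.0,
) -> torch.Tensor:
    # F.cross_entropy returns -log p(target); negate to get log-probs.
    logits = logits.float() / temperature
    return -F.cross_entropy(
        logits.flatten(0, 1),        # [B*T, V]
        targets.flatten(),           # [B*T]
        reduction="none",
    ).view_as(targets)
```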

Also, you might be interested in checking Nathan's old PRs in torchtune: meta-pytorch/torchtune#2782
