feat: Reduce reference model memory with parallel logprob computation #608
base: main
Conversation
joecummings left a comment:
I like this idea!
Could I ask for a few things?
- WandB logs that show the memory saved. This is always helpful as part of verifying correctness.
- Combine the parallel_logprobs and regular logprobs in the same file. No need to split that out just yet.
- Look for ways that this code could be simplified and/or factored out. Claude can be very verbose :)
Looking forward to getting this landed!
Thanks for the review @joecummings! I refactored the code as per the feedback. Less Claude footprint now :). Let me know if the code needs to be further simplified or refactored. I attached the WandB chart images in the description and am also attaching them here:
Old state usage:
New state (Parallel logprobs based) usage:
…reference model: This update introduces a function to compute log probabilities without gathering the full vocabulary tensor across GPUs.
53ddb5b to 20f59bf
Hi @felipemello1, the unit tests were failing. I ran them locally and they all pass. Can you trigger the tests again, please? Thanks!
@gitlost-murali, thanks for opening the PR. Great results! Can you try running the non-sharded version but with F.cross_entropy compiled? I think that simply compiling it greatly reduces the memory, since it never materializes the intermediate activations. Maybe something to do in addition to your work rather than in place of it. I am skeptical about using log-sum-exp directly instead of F.cross_entropy, since the compiled version is highly optimized. Also, you might be interested in checking Nathan's old PRs in torchtune: meta-pytorch/torchtune#2782
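For context, here is a minimal sketch of the suggestion above (not code from the PR): wrapping a cross-entropy-based logprob computation in torch.compile, so the full log-softmax intermediate can be fused rather than materialized. The function name and tensor shapes are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical helper illustrating the "compile F.cross_entropy" suggestion;
# the actual non-sharded path in the repo may look different.
@torch.compile
def compute_logprobs_compiled(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # F.cross_entropy returns -log p(target); negate to recover per-token log-probs.
    losses = F.cross_entropy(
        logits.flatten(0, 1).float(),  # [batch * seq, vocab]
        targets.flatten(0, 1),         # [batch * seq]
        reduction="none",
    )
    return -losses.view_as(targets)    # back to [batch, seq]
```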


Summary
When tensor parallelism is enabled, the reference model's logits are sharded across GPUs along the vocabulary dimension. Previously, we called `full_tensor()` to gather the complete vocabulary on each GPU before computing log probabilities. This PR adds `compute_logprobs_parallel()`, which computes log probabilities in a distributed fashion using the log-sum-exp trick across shards.
Memory savings (measured)
Old state usage:
New state (Parallel logprobs based) usage:
Tested with batch=4, seq_len=9k (1024 prompt tokens + 8192 response tokens), vocab=150k, TP=2
Changes
- `src/forge/util/parallel_logprobs.py` - distributed log-prob computation for vocab-sharded DTensors
- `tests/unit_tests/util/test_parallel_logprobs.py` - correctness tests against the sequential implementation
- `src/forge/actors/reference_model.py` - uses the parallel version when TP is enabled
Implementation
Uses a distributed log-softmax without gathering the full vocabulary:
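The code block that originally followed did not survive extraction. Below is a minimal sketch of the shard-wise log-sum-exp approach described above, written against plain `torch.distributed` collectives rather than the PR's DTensor-based `compute_logprobs_parallel()`; the function name, argument layout, and `vocab_start` bookkeeping are assumptions for illustration.

```python
from typing import Optional

import torch
import torch.distributed as dist

def sharded_logprobs_sketch(
    local_logits: torch.Tensor,   # [batch, seq, vocab_shard], this rank's vocab slice
    target_ids: torch.Tensor,     # [batch, seq], ids in the global vocabulary
    vocab_start: int,             # first global vocab index owned by this rank
    group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:
    """log p(target) = logit[target] - logsumexp(logits over the full vocab),
    computed without gathering the full vocabulary on any rank."""
    logits = local_logits.float()

    # Global log-sum-exp: combine per-shard logsumexps with a max-stabilized reduce.
    local_lse = torch.logsumexp(logits, dim=-1)                 # [batch, seq]
    global_max = logits.amax(dim=-1)
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=group)
    shifted = (local_lse - global_max).exp()
    dist.all_reduce(shifted, op=dist.ReduceOp.SUM, group=group)
    global_lse = shifted.log() + global_max                     # [batch, seq]

    # Target logit: only the rank owning the target id contributes; summing
    # across ranks leaves every rank with the true value.
    vocab_shard = logits.size(-1)
    local_ids = target_ids - vocab_start
    in_shard = (local_ids >= 0) & (local_ids < vocab_shard)
    safe_ids = local_ids.clamp(0, vocab_shard - 1).unsqueeze(-1)
    target_logit = logits.gather(-1, safe_ids).squeeze(-1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=group)

    return target_logit - global_lse                            # [batch, seq] log-probs
```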
Testing
- The parallel implementation matches `compute_logprobs()` within 1e-5 tolerance
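As a rough illustration of that check (this is not the PR's actual test; the sharding here is simulated in a single process), one can split the vocab dimension into chunks, apply the same shard-wise reduction, and compare against a plain log_softmax:

```python
import torch

def test_sharded_logsumexp_matches_full() -> None:
    torch.manual_seed(0)
    batch, seq, vocab, shards = 2, 5, 64, 4
    shard_size = vocab // shards
    logits = torch.randn(batch, seq, vocab)
    targets = torch.randint(0, vocab, (batch, seq))

    # Reference: full log-softmax, then gather the target log-probs.
    expected = torch.log_softmax(logits, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # "Sharded" path: combine per-shard logsumexps and per-shard target logits.
    chunks = logits.chunk(shards, dim=-1)
    lse = torch.logsumexp(torch.stack([c.logsumexp(dim=-1) for c in chunks]), dim=0)
    target_logit = torch.zeros(batch, seq)
    for i, c in enumerate(chunks):
        local = targets - i * shard_size
        mask = (local >= 0) & (local < shard_size)
        picked = c.gather(-1, local.clamp(0, shard_size - 1).unsqueeze(-1)).squeeze(-1)
        target_logit += torch.where(mask, picked, torch.zeros_like(picked))
    actual = target_logit - lse

    torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
```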