[tx] Compute prompt logprobs on per-micro-batch basis #727
base: main
Conversation
/gemini review
Code Review
This pull request refactors prompt logprob computation to happen on a per-micro-batch basis rather than for the entire batch. This is a sensible optimization that avoids unnecessary work when only some requests in a batch require logprobs. The changes to the test suite correctly validate the new behavior by including a mix of requests with and without logprobs enabled. The implementation is mostly correct, but I've identified a potential robustness issue where a TypeError could occur.
```python
all_prompt_logprobs.extend(
    result.prompt_logprobs[i] if request_logprobs else None
    for i, request_logprobs in enumerate(all_request_logprobs[batch_start:batch_end])
)
```
The current implementation for extending all_prompt_logprobs could lead to a TypeError if model.generate returns None for prompt_logprobs even when they were requested for the micro-batch (needs_prompt_logprobs is True). The previous code included a check for result.prompt_logprobs being truthy, which was removed, making the access to result.prompt_logprobs[i] potentially unsafe.
To make the code more robust, it's better to reintroduce a check for result.prompt_logprobs being non-None before attempting to index it. The suggested change below does this in a compact way.
```diff
 all_prompt_logprobs.extend(
-    result.prompt_logprobs[i] if request_logprobs else None
+    result.prompt_logprobs[i] if request_logprobs and result.prompt_logprobs else None
     for i, request_logprobs in enumerate(all_request_logprobs[batch_start:batch_end])
 )
```
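To make the failure mode concrete, here is a minimal, self-contained sketch (using a hypothetical `SimpleNamespace` result object rather than the repo's actual types) of why the extra guard avoids the TypeError:

```python
from types import SimpleNamespace

# Hypothetical stand-in for a generation result whose prompt_logprobs came back as None.
result = SimpleNamespace(prompt_logprobs=None)
request_logprobs = True

# Unguarded form: indexing None raises TypeError ('NoneType' object is not subscriptable).
# value = result.prompt_logprobs[0] if request_logprobs else None

# Guarded form: the `and` short-circuits on a falsy prompt_logprobs and yields None instead.
value = result.prompt_logprobs[0] if request_logprobs and result.prompt_logprobs else None
print(value)  # -> None
```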
In response to the comment here: #577 (comment)
Prior to this PR, we check whether any request in a batch has requested prompt logprobs and, if so, compute prompt logprobs for all samples in the batch. This PR moves that check to the per-micro-batch level, avoiding unnecessary computation when a micro batch contains no logprob requests.
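A rough sketch of the per-micro-batch control flow described above; the function name, `micro_batch_size` parameter, and the `model.generate` signature below are hypothetical stand-ins, not the actual implementation:

```python
def generate_with_microbatching(prompts, all_request_logprobs, micro_batch_size, model):
    """Sketch: compute prompt logprobs only for micro batches that need them."""
    all_prompt_logprobs = []
    for batch_start in range(0, len(prompts), micro_batch_size):
        batch_end = batch_start + micro_batch_size
        micro_request_logprobs = all_request_logprobs[batch_start:batch_end]

        # Previously this check spanned the whole batch; it is now scoped to the
        # current micro batch, so micro batches with no logprob requests skip
        # the extra computation entirely.
        needs_prompt_logprobs = any(micro_request_logprobs)

        result = model.generate(
            prompts[batch_start:batch_end],
            prompt_logprobs=needs_prompt_logprobs,
        )

        all_prompt_logprobs.extend(
            result.prompt_logprobs[i] if request_logprobs and result.prompt_logprobs else None
            for i, request_logprobs in enumerate(micro_request_logprobs)
        )
    return all_prompt_logprobs
```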
Tested with the existing micro-batching test: `test_sample_prompt_logprobs_with_microbatching`.