
Conversation

tyler-griggs (Member) commented Dec 1, 2025

In response to the comment here: #577 (comment)

Prior to this PR, we check whether any request in a batch has requested prompt logprobs and, if so, compute prompt logprobs for all samples in the batch. This PR moves that check to each micro batch, so micro batches containing no such requests skip the computation entirely.
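
Roughly, the change looks like this (a sketch only: `run_micro_batches` and `run_model` are illustrative stand-ins, not the actual engine code; the `extend` expression mirrors the diff below):

```python
# Illustrative sketch: `run_model` stands in for the real generate call,
# and the loop structure is assumed, not copied from the engine.
def run_micro_batches(all_request_logprobs, micro_batch_size, run_model):
    all_prompt_logprobs = []
    n = len(all_request_logprobs)
    for batch_start in range(0, n, micro_batch_size):
        batch_end = min(batch_start + micro_batch_size, n)
        request_flags = all_request_logprobs[batch_start:batch_end]
        # The change: the "do we need prompt logprobs?" check is now scoped
        # to this micro batch instead of being computed once per whole batch.
        needs_prompt_logprobs = any(request_flags)
        result = run_model(batch_start, batch_end, needs_prompt_logprobs)
        all_prompt_logprobs.extend(
            result.prompt_logprobs[i] if requested else None
            for i, requested in enumerate(request_flags)
        )
    return all_prompt_logprobs
```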

Tested with the existing micro-batch test: `test_sample_prompt_logprobs_with_microbatching`

tyler-griggs added the tx label Dec 1, 2025
tyler-griggs marked this pull request as ready for review December 1, 2025 21:34
pcmoritz (Collaborator) commented Dec 2, 2025

/gemini review

gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request refactors the prompt logprobs computation so that it is decided on a per-micro-batch basis rather than for the entire batch. This is a sensible optimization that avoids unnecessary work when only some requests in a batch require logprobs. The changes to the test suite correctly validate the new behavior by including a mix of requests with and without logprobs enabled. The implementation is mostly correct, but I've identified a potential robustness issue where a `TypeError` could occur.

Comment on lines +685 to +688
```python
all_prompt_logprobs.extend(
    result.prompt_logprobs[i] if request_logprobs else None
    for i, request_logprobs in enumerate(all_request_logprobs[batch_start:batch_end])
)
```

Severity: high

The current implementation for extending `all_prompt_logprobs` could lead to a `TypeError` if `model.generate` returns `None` for `prompt_logprobs` even when they were requested for the micro-batch (`needs_prompt_logprobs` is `True`). The previous code included a check for `result.prompt_logprobs` being truthy, which was removed, making the access to `result.prompt_logprobs[i]` potentially unsafe.

To make the code more robust, it's better to reintroduce a check for `result.prompt_logprobs` being non-`None` before attempting to index it. The suggested change below does this in a compact way.

Suggested change

```diff
 all_prompt_logprobs.extend(
-    result.prompt_logprobs[i] if request_logprobs else None
+    result.prompt_logprobs[i] if request_logprobs and result.prompt_logprobs else None
     for i, request_logprobs in enumerate(all_request_logprobs[batch_start:batch_end])
 )
```
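
For reference, the guard relies on Python's short-circuiting `and`: when `result.prompt_logprobs` is `None`, the conditional yields `None` instead of indexing. A standalone illustration with placeholder names (not the engine's actual objects):

```python
prompt_logprobs = None   # what generate might hand back despite the request
request_logprobs = True

# Without the guard, `prompt_logprobs[0]` would raise:
# TypeError: 'NoneType' object is not subscriptable
value = prompt_logprobs[0] if request_logprobs and prompt_logprobs else None
assert value is None  # `and` short-circuits before the indexing happens
```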

