File tree Expand file tree Collapse file tree 2 files changed +26
-1
lines changed
Expand file tree Collapse file tree 2 files changed +26
-1
lines changed Original file line number Diff line number Diff line change 1+ # SPDX-License-Identifier: Apache-2.0
2+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+ """Some utilities for logprobs, including logits."""
4+
5+ import torch
6+
7+
8+ @torch .compile (dynamic = True )
9+ def batched_count_greater_than (x : torch .Tensor ,
10+ values : torch .Tensor ) -> torch .Tensor :
11+ """
12+ Counts elements in each row of x that are greater than the corresponding
13+ value in values. Use torch.compile to generate an optimized kernel for
14+ this function. otherwise, it will create additional copies of the input
15+ tensors and cause memory issues.
16+
17+ Args:
18+ x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
19+ values (torch.Tensor): A 2D tensor of shape (batch_size, 1).
20+
21+ Returns:
22+ torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
23+ """
24+ return (x >= values ).sum (- 1 )
Original file line number Diff line number Diff line change 99from vllm .v1 .outputs import LogprobsTensors , SamplerOutput
1010from vllm .v1 .sample .metadata import SamplingMetadata
1111from vllm .v1 .sample .ops .bad_words import apply_bad_words
12+ from vllm .v1 .sample .ops .logprobs import batched_count_greater_than
1213from vllm .v1 .sample .ops .penalties import apply_all_penalties
1314from vllm .v1 .sample .ops .topk_topp_sampler import TopKTopPSampler
1415
@@ -174,7 +175,7 @@ def gather_logprobs(
174175 token_logprobs = logprobs .gather (- 1 , token_ids )
175176
176177 # Compute the ranks of the actual token.
177- token_ranks = (logprobs >= token_logprobs ). sum ( - 1 )
178+ token_ranks = batched_count_greater_than (logprobs , token_logprobs )
178179
179180 # Concatenate together with the topk.
180181 indices = torch .cat ((token_ids , topk_indices ), dim = 1 )
You can’t perform that action at this time.
0 commit comments