Commit ddcbc2f

[Misc] Misc code simplifications (vllm-project#26450)
Signed-off-by: Nick Hill <[email protected]>
1 parent a83ff27 · commit ddcbc2f

File tree

6 files changed: +79 -90 lines changed

vllm/v1/core/sched/scheduler.py

Lines changed: 1 addition & 1 deletion

@@ -1474,7 +1474,7 @@ def _update_requests_with_invalid_blocks(
 
             affected_req_ids.add(request.request_id)
 
-        return (affected_req_ids, total_affected_tokens)
+        return affected_req_ids, total_affected_tokens
 
     def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
         total_requests_to_reschedule = 0
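
The only change in this file is dropping the redundant parentheses around a returned tuple; in Python the comma builds the tuple, not the parentheses. A minimal standalone sketch (names are illustrative, not taken from vLLM):

def affected() -> tuple[set[str], int]:
    ids, total = {"req-1", "req-2"}, 7
    # Both forms build the same 2-tuple; the parentheses add nothing.
    return ids, total  # equivalent to: return (ids, total)

req_ids, num_tokens = affected()
assert (req_ids, num_tokens) == affected()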

vllm/v1/core/sched/utils.py

Lines changed: 1 addition & 2 deletions

@@ -59,8 +59,7 @@ def check_stop(
     sampling_params = request.sampling_params
     assert sampling_params is not None
 
-    min_tokens = sampling_params.min_tokens
-    if request.num_output_tokens < min_tokens:
+    if request.num_output_tokens < sampling_params.min_tokens:
         return False
 
     last_token_id = request.output_token_ids[-1]
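
Here a single-use local (`min_tokens`) is inlined into the comparison that guards early exit. A hypothetical, simplified stop check in the same spirit; the dataclasses below are stand-ins, not vLLM's real `SamplingParams`/`Request` types:

from dataclasses import dataclass, field


@dataclass
class SamplingParams:  # stand-in for the real SamplingParams
    min_tokens: int = 0


@dataclass
class Request:  # stand-in for the scheduler's Request
    output_token_ids: list[int] = field(default_factory=list)
    sampling_params: SamplingParams = field(default_factory=SamplingParams)

    @property
    def num_output_tokens(self) -> int:
        return len(self.output_token_ids)


def check_stop(request: Request) -> bool:
    sampling_params = request.sampling_params
    # Too few output tokens yet: never stop, regardless of EOS/stop strings.
    if request.num_output_tokens < sampling_params.min_tokens:
        return False
    return True  # the real function continues with stop-token/length checks


assert check_stop(Request([1, 2], SamplingParams(min_tokens=4))) is False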

vllm/v1/sample/rejection_sampler.py

Lines changed: 8 additions & 13 deletions

@@ -147,22 +147,20 @@ def apply_logits_processors(
         sampling_metadata: SamplingMetadata,
         metadata: SpecDecodeMetadata,
     ) -> torch.Tensor:
+        has_penalties = not sampling_metadata.no_penalties
         any_penalties_or_bad_words = (
-            sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
+            sampling_metadata.bad_words_token_ids or has_penalties
         )
 
         output_token_ids = sampling_metadata.output_token_ids
         if any_penalties_or_bad_words:
             output_token_ids = self._combine_outputs_with_spec_tokens(
-                sampling_metadata.output_token_ids,
+                output_token_ids,
                 sampling_metadata.spec_token_ids,
             )
 
         # Calculate indices of target logits.
-        if (
-            sampling_metadata.allowed_token_ids_mask is not None
-            or not sampling_metadata.no_penalties
-        ):
+        if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
             num_requests = len(sampling_metadata.output_token_ids)
             num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
             original_indices = torch.arange(num_requests, device="cpu")

@@ -180,18 +178,15 @@ def apply_logits_processors(
             logits.masked_fill_(token_mask, float("-inf"))
 
         # Apply bad words exclusion.
-        if sampling_metadata.bad_words_token_ids:
+        if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
             apply_bad_words_with_drafts(
-                logits,
-                sampling_metadata.bad_words_token_ids,
-                output_token_ids,
-                metadata.num_draft_tokens,
+                logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
             )
 
         return logits
 
+    @staticmethod
     def apply_penalties(
-        self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         metadata: SpecDecodeMetadata,

@@ -218,8 +213,8 @@ def apply_penalties(
         )
         return logits
 
+    @staticmethod
     def _combine_outputs_with_spec_tokens(
-        self,
         output_token_ids: list[list[int]],
         spec_token_ids: Optional[list[list[int]]] = None,
     ) -> list[list[int]]:
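
Two patterns recur in this file: a repeated subexpression (`not sampling_metadata.no_penalties`) is hoisted into a `has_penalties` local, and the walrus operator binds `bad_words_token_ids` inside the `if` that guards its only use. A small self-contained sketch of the walrus pattern; `mask_bad_words` is a toy stand-in for `apply_bad_words_with_drafts`:

def mask_bad_words(logits: list[float], bad_ids: list[int]) -> None:
    # Toy stand-in: mask the banned token ids in place.
    for token_id in bad_ids:
        logits[token_id] = float("-inf")


def process(logits: list[float], metadata: dict) -> list[float]:
    # Before: test `metadata.get(...)`, then look it up again in the body.
    # The walrus operator binds and tests the value in one step.
    if bad_words_token_ids := metadata.get("bad_words_token_ids"):
        mask_bad_words(logits, bad_words_token_ids)
    return logits


print(process([0.1, 0.2, 0.3], {"bad_words_token_ids": [1]}))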

vllm/v1/sample/sampler.py

Lines changed: 26 additions & 30 deletions

@@ -120,8 +120,8 @@ def forward(
         )
         return sampler_output
 
+    @staticmethod
     def apply_temperature(
-        self,
         logits: torch.Tensor,
         temp: torch.Tensor,
         all_random: bool,

@@ -132,7 +132,8 @@ def apply_temperature(
         temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
         return logits.div_(temp.unsqueeze(dim=1))
 
-    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
+    @staticmethod
+    def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
         return logits.argmax(dim=-1).view(-1)
 
     def sample(

@@ -191,11 +192,12 @@ def sample(
         )
         return sampled, processed_logprobs
 
-    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
+    @staticmethod
+    def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
         return logits.log_softmax(dim=-1, dtype=torch.float32)
 
+    @staticmethod
     def gather_logprobs(
-        self,
         logprobs: torch.Tensor,
         num_logprobs: int,
         token_ids: torch.Tensor,

@@ -238,8 +240,8 @@ def gather_logprobs(
 
         return LogprobsTensors(indices, logprobs, token_ranks)
 
+    @staticmethod
     def _combine_outputs_with_spec_tokens(
-        self,
         output_token_ids: list[list[int]],
         spec_token_ids: Optional[list[list[int]]] = None,
     ) -> list[list[int]]:

@@ -257,16 +259,17 @@ def apply_logits_processors(
         sampling_metadata: SamplingMetadata,
         predict_bonus_token: bool,
     ) -> torch.Tensor:
+        bad_words_token_ids = sampling_metadata.bad_words_token_ids
         any_penalties_or_bad_words = (
-            sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
+            bool(bad_words_token_ids) or not sampling_metadata.no_penalties
        )
 
         output_token_ids = sampling_metadata.output_token_ids
         if predict_bonus_token and any_penalties_or_bad_words:
             # Combine base outputs with spec tokens when speculative decoding
             # is enabled.
             output_token_ids = self._combine_outputs_with_spec_tokens(
-                sampling_metadata.output_token_ids,
+                output_token_ids,
                 sampling_metadata.spec_token_ids,
             )
 

@@ -275,14 +278,8 @@ def apply_logits_processors(
             logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
 
         # Apply bad words exclusion.
-        if sampling_metadata.bad_words_token_ids:
-            apply_bad_words(
-                logits,
-                sampling_metadata.bad_words_token_ids,
-                output_token_ids
-                if output_token_ids is not None
-                else sampling_metadata.output_token_ids,
-            )
+        if bad_words_token_ids:
+            apply_bad_words(logits, bad_words_token_ids, output_token_ids)
 
         # Apply logits processors which can impact greedy sampling.
         for processor in sampling_metadata.logitsprocs.non_argmax_invariant:

@@ -292,22 +289,21 @@ def apply_logits_processors(
         logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
         return logits
 
+    @staticmethod
     def apply_penalties(
-        self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        output_token_ids: Optional[list[list[int]]] = None,
+        output_token_ids: list[list[int]],
     ) -> torch.Tensor:
-        if not sampling_metadata.no_penalties:
-            assert sampling_metadata.prompt_token_ids is not None
-            logits = apply_all_penalties(
-                logits,
-                sampling_metadata.prompt_token_ids,
-                sampling_metadata.presence_penalties,
-                sampling_metadata.frequency_penalties,
-                sampling_metadata.repetition_penalties,
-                output_token_ids
-                if output_token_ids is not None
-                else sampling_metadata.output_token_ids,
-            )
-        return logits
+        if sampling_metadata.no_penalties:
+            return logits
+
+        assert sampling_metadata.prompt_token_ids is not None
+        return apply_all_penalties(
+            logits,
+            sampling_metadata.prompt_token_ids,
+            sampling_metadata.presence_penalties,
+            sampling_metadata.frequency_penalties,
+            sampling_metadata.repetition_penalties,
+            output_token_ids,
+        )
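
The sampler changes follow two patterns: methods that never read `self` become `@staticmethod`s, and `apply_penalties` uses a guard clause (early return) instead of nesting its whole body under `if not sampling_metadata.no_penalties`. A hedged sketch of both, using toy list-based "logits" rather than the real tensors and `SamplingMetadata`:

class ToySampler:
    @staticmethod
    def greedy_sample(logits: list[float]) -> int:
        # No instance state is touched, so a staticmethod avoids a dummy `self`.
        return max(range(len(logits)), key=logits.__getitem__)

    @staticmethod
    def apply_penalties(
        logits: list[float], no_penalties: bool, penalty: float = 0.5
    ) -> list[float]:
        # Guard clause: bail out early, keep the main path unindented.
        if no_penalties:
            return logits
        return [x - penalty for x in logits]


assert ToySampler.greedy_sample([0.1, 0.9, 0.3]) == 1
assert ToySampler.apply_penalties([1.0], no_penalties=True) == [1.0]
assert ToySampler.apply_penalties([1.0], no_penalties=False) == [0.5]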

vllm/v1/worker/gpu_input_batch.py

Lines changed: 9 additions & 11 deletions

@@ -62,10 +62,9 @@ def get_token_id(self, idx: int) -> int:
                     "provided via prompt_embeds, and its ID is unknown."
                 )
             return self.prompt_token_ids[idx]
-        elif idx - self.num_prompt_tokens < len(self.output_token_ids):
+        if idx - self.num_prompt_tokens < len(self.output_token_ids):
             return self.output_token_ids[idx - self.num_prompt_tokens]
-        else:
-            return -1
+        return -1
 
 
 class InputBatch:

@@ -770,14 +769,13 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             not self.no_penalties
             or self.logits_processing_needs_token_ids[:num_reqs].any()
         )
-        if needs_prompt_token_ids:
-            # The prompt tokens are used only for applying penalties or
-            # step pooling during the sampling/pooling process.
-            # Hence copy these tensors only when there are requests which
-            # need penalties/step_pooler to be applied.
-            prompt_token_ids = self._make_prompt_token_ids_tensor()
-        else:
-            prompt_token_ids = None
+        # The prompt tokens are used only for applying penalties or
+        # step pooling during the sampling/pooling process.
+        # Hence copy these tensors only when there are requests which
+        # need penalties/step_pooler to be applied.
+        prompt_token_ids = (
+            self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
+        )
 
         allowed_token_ids_mask: Optional[torch.Tensor] = None
         if not self.no_allowed_token_ids:
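
`get_token_id` drops the `elif`/`else` that followed a `return`, and the prompt-token-ids branch collapses into a conditional expression. A simplified sketch over toy data (a toy class, not the real ones in gpu_input_batch.py):

class ToyRequestState:
    def __init__(self, prompt: list[int], output: list[int]) -> None:
        self.prompt_token_ids = prompt
        self.output_token_ids = output
        self.num_prompt_tokens = len(prompt)

    def get_token_id(self, idx: int) -> int:
        if idx < self.num_prompt_tokens:
            return self.prompt_token_ids[idx]
        # After a `return`, `elif`/`else` add nothing: fall through directly.
        if idx - self.num_prompt_tokens < len(self.output_token_ids):
            return self.output_token_ids[idx - self.num_prompt_tokens]
        return -1


state = ToyRequestState(prompt=[10, 11], output=[12])
assert [state.get_token_id(i) for i in range(4)] == [10, 11, 12, -1]

# The same collapse applied to a branch that only chooses a value:
needs_prompt_token_ids = False
prompt_token_ids = [10, 11] if needs_prompt_token_ids else None
assert prompt_token_ids is None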

vllm/v1/worker/gpu_model_runner.py

Lines changed: 34 additions & 33 deletions

@@ -1996,7 +1996,8 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None:
     # Should be called after attention metadata creation. This just pads
     # the second ubatch slice out to the total number of tokens
     # (num_tokens + padding)
-    def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, num_total_tokens: int):
+    @staticmethod
+    def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
         padded_second_ubatch_slice = slice(
             ubatch_slices[1].token_slice.start, num_total_tokens
         )

@@ -2085,12 +2086,13 @@ def _preprocess(
         dict[str, Any],
     ]:
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        is_first_rank = get_pp_group().is_first_rank
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
         if (
             self.supports_mm_inputs
-            and get_pp_group().is_first_rank
+            and is_first_rank
             and not self.model_config.is_encoder_decoder
         ):
             # Run the multimodal encoder if any.

@@ -2115,7 +2117,7 @@ def _preprocess(
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
             }
-        elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
+        elif self.enable_prompt_embeds and is_first_rank:
             # Get the input embeddings for the tokens that are not input embeds,
             # then put them into the appropriate positions.
             # TODO(qthequartermasterman): Since even when prompt embeds are

@@ -2155,7 +2157,7 @@ def _preprocess(
         else:
             positions = self.positions.gpu[:num_input_tokens]
 
-        if get_pp_group().is_first_rank:
+        if is_first_rank:
             intermediate_tensors = None
         else:
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(

@@ -2186,38 +2188,37 @@ def _sample(
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         if spec_decode_metadata is None:
-            sampler_output = self.sampler(
+            return self.sampler(
                 logits=logits,
                 sampling_metadata=sampling_metadata,
             )
-        else:
-            # When indexing with a tensor (bonus_logits_indices), PyTorch
-            # creates a new tensor with separate storage from the original
-            # logits tensor. This means any in-place operations on bonus_logits
-            # won't affect the original logits tensor.
-            assert logits is not None
-            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
-            sampler_output = self.sampler(
-                logits=bonus_logits,
-                sampling_metadata=sampling_metadata,
-                predict_bonus_token=True,
-            )
-            bonus_token_ids = sampler_output.sampled_token_ids
-
-            # Just like `bonus_logits`, `target_logits` is a new tensor with
-            # separate storage from the original `logits` tensor. Therefore,
-            # it is safe to update `target_logits` in place.
-            target_logits = logits[spec_decode_metadata.target_logits_indices]
-            output_token_ids = self.rejection_sampler(
-                spec_decode_metadata,
-                None,  # draft_probs
-                target_logits,
-                bonus_token_ids,
-                sampling_metadata,
-            )
-            sampler_output.sampled_token_ids = output_token_ids
-            self._update_states_after_model_execute(output_token_ids)
 
+        # When indexing with a tensor (bonus_logits_indices), PyTorch
+        # creates a new tensor with separate storage from the original
+        # logits tensor. This means any in-place operations on bonus_logits
+        # won't affect the original logits tensor.
+        assert logits is not None
+        bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
+        sampler_output = self.sampler(
+            logits=bonus_logits,
+            sampling_metadata=sampling_metadata,
+            predict_bonus_token=True,
+        )
+        bonus_token_ids = sampler_output.sampled_token_ids
+
+        # Just like `bonus_logits`, `target_logits` is a new tensor with
+        # separate storage from the original `logits` tensor. Therefore,
+        # it is safe to update `target_logits` in place.
+        target_logits = logits[spec_decode_metadata.target_logits_indices]
+        output_token_ids = self.rejection_sampler(
+            spec_decode_metadata,
+            None,  # draft_probs
+            target_logits,
+            bonus_token_ids,
+            sampling_metadata,
+        )
+        sampler_output.sampled_token_ids = output_token_ids
+        self._update_states_after_model_execute(output_token_ids)
         return sampler_output
 
     def _bookkeeping_sync(

@@ -3741,7 +3742,7 @@ def freeze_gc():
             decode_cudagraph_batch_sizes = [
                 x
                 for x in self.cudagraph_batch_sizes
-                if x <= max_num_tokens and x >= self.uniform_decode_query_len
+                if max_num_tokens >= x >= self.uniform_decode_query_len
             ]
             compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
             self._capture_cudagraphs(
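
Three simplifications in this file: `get_pp_group().is_first_rank` is fetched once into an `is_first_rank` local and reused, `_sample` returns early in the non-speculative case so the speculative-decoding path loses one level of indentation, and the CUDA-graph batch-size filter uses a chained comparison. A minimal sketch of the chained comparison; the batch sizes below are made up for illustration:

cudagraph_batch_sizes = [1, 2, 4, 8, 16, 32, 64]
uniform_decode_query_len = 4
max_num_tokens = 32

# `hi >= x >= lo` evaluates `x` once and reads as a single range check,
# replacing `x <= hi and x >= lo`.
decode_sizes = [
    x
    for x in cudagraph_batch_sizes
    if max_num_tokens >= x >= uniform_decode_query_len
]
assert decode_sizes == [4, 8, 16, 32]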
