Commit 0879736

[Perf] Remove hardcoded num_warps=1 (vllm-project#26183)
Signed-off-by: Corey Lowman <[email protected]>
1 parent a269173 commit 0879736

File tree

1 file changed (+2, -5 lines)


vllm/v1/sample/rejection_sampler.py

Lines changed: 2 additions & 5 deletions
@@ -164,12 +164,12 @@ def rejection_sample(
     assert target_probs.shape == (num_tokens, vocab_size)
 
     # Create output buffer.
-    output_token_ids = torch.empty(
+    output_token_ids = torch.full(
         (batch_size, max_spec_len + 1),
+        PLACEHOLDER_TOKEN_ID,
         dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
         device=device,
     )
-    output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
 
     if sampling_metadata.all_greedy:
         is_greedy = None
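For reference, a minimal standalone sketch (not the vLLM code itself; the placeholder value, shapes, and device handling below are illustrative assumptions) of what this buffer-creation change does: torch.full allocates and initializes the buffer in a single call, replacing the two-step torch.empty(...) followed by a separate Tensor.fill_(...).

import torch

# Assumed values for illustration only; vLLM defines its own
# PLACEHOLDER_TOKEN_ID and derives the shape from the batch.
PLACEHOLDER_TOKEN_ID = -1
batch_size, max_spec_len = 4, 3
device = "cuda" if torch.cuda.is_available() else "cpu"

# Before: allocate uninitialized memory, then run a separate fill op.
buf_before = torch.empty(
    (batch_size, max_spec_len + 1),
    dtype=torch.int32,
    device=device,
)
buf_before.fill_(PLACEHOLDER_TOKEN_ID)

# After: one call both allocates and fills.
buf_after = torch.full(
    (batch_size, max_spec_len + 1),
    PLACEHOLDER_TOKEN_ID,
    dtype=torch.int32,
    device=device,
)

assert torch.equal(buf_before, buf_after)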
@@ -186,7 +186,6 @@ def rejection_sample(
         bonus_token_ids,
         is_greedy,
         max_spec_len,
-        num_warps=1,
     )
     if sampling_metadata.all_greedy:
         return output_token_ids
@@ -227,7 +226,6 @@ def rejection_sample(
         max_spec_len,
         vocab_size,
         NO_DRAFT_PROBS=draft_probs is None,
-        num_warps=1,
     )
     return output_token_ids
 
@@ -329,7 +327,6 @@ def expand_batch_to_tokens(
         replace_from,
         replace_to,
         MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
-        num_warps=1,
     )
     return expanded_x
 
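The other change, repeated across the three hunks above, drops the hardcoded num_warps=1 from the Triton kernel launches. Below is a minimal, self-contained Triton sketch (a hypothetical fill kernel, not one of the vLLM kernels) of what that launch argument controls: num_warps=1 pins each program instance to a single 32-thread warp, while omitting it lets Triton use its default (typically 4 warps), which generally improves occupancy and memory-latency hiding.

# Requires a CUDA GPU with Triton installed.
import torch
import triton
import triton.language as tl


@triton.jit
def fill_kernel(out_ptr, value, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance fills one BLOCK_SIZE-wide slice of the output.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(out_ptr + offsets, value, mask=mask)


x = torch.empty(4096, dtype=torch.int32, device="cuda")
grid = (triton.cdiv(x.numel(), 1024),)

# Before: all 1024 elements of each block are handled by one 32-thread warp.
fill_kernel[grid](x, -1, x.numel(), BLOCK_SIZE=1024, num_warps=1)

# After: Triton's default warp count spreads the block's work across more threads.
fill_kernel[grid](x, -1, x.numel(), BLOCK_SIZE=1024)

With each program instance covering a sizable block of tokens, forcing a single warp serializes work that more warps could overlap, so removing the hardcoded value and falling back to Triton's default is the likely source of the perf win named in the commit title.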
