File tree Expand file tree Collapse file tree 1 file changed +2
-5
lines changed Expand file tree Collapse file tree 1 file changed +2
-5
lines changed Original file line number Diff line number Diff line change @@ -164,12 +164,12 @@ def rejection_sample(
164
164
assert target_probs .shape == (num_tokens , vocab_size )
165
165
166
166
# Create output buffer.
167
- output_token_ids = torch .empty (
167
+ output_token_ids = torch .full (
168
168
(batch_size , max_spec_len + 1 ),
169
+ PLACEHOLDER_TOKEN_ID ,
169
170
dtype = torch .int32 , # Consistent with SamplerOutput.sampled_token_ids.
170
171
device = device ,
171
172
)
172
- output_token_ids .fill_ (PLACEHOLDER_TOKEN_ID )
173
173
174
174
if sampling_metadata .all_greedy :
175
175
is_greedy = None
@@ -186,7 +186,6 @@ def rejection_sample(
186
186
bonus_token_ids ,
187
187
is_greedy ,
188
188
max_spec_len ,
189
- num_warps = 1 ,
190
189
)
191
190
if sampling_metadata .all_greedy :
192
191
return output_token_ids
@@ -227,7 +226,6 @@ def rejection_sample(
227
226
max_spec_len ,
228
227
vocab_size ,
229
228
NO_DRAFT_PROBS = draft_probs is None ,
230
- num_warps = 1 ,
231
229
)
232
230
return output_token_ids
233
231
@@ -329,7 +327,6 @@ def expand_batch_to_tokens(
329
327
replace_from ,
330
328
replace_to ,
331
329
MAX_NUM_TOKENS = MAX_SPEC_LEN , # To avoid recompilation.
332
- num_warps = 1 ,
333
330
)
334
331
return expanded_x
335
332
You can’t perform that action at this time.
0 commit comments