@@ -1996,7 +1996,8 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None:
     # Should be called after attention metadata creation. This just pads
     # the second ubatch slice out to the total number of tokens
     # (num_tokens + padding)
-    def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, num_total_tokens: int):
+    @staticmethod
+    def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
         padded_second_ubatch_slice = slice(
             ubatch_slices[1].token_slice.start, num_total_tokens
         )
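Review note: the @staticmethod conversion is safe because the method never reads self. A minimal sketch of what the padding does, using a hypothetical UBatchSliceStub in place of vLLM's UBatchSlice (only the token_slice field is modeled, and the assignment back into the list is assumed, since the hunk cuts off after the slice construction):

    from dataclasses import dataclass

    @dataclass
    class UBatchSliceStub:
        # hypothetical stand-in: real UBatchSlice carries more fields
        token_slice: slice

    def pad_out_ubatch_slice(ubatch_slices: list, num_total_tokens: int) -> None:
        # Extend the second micro-batch so both slices together cover
        # num_total_tokens (scheduled tokens + padding).
        start = ubatch_slices[1].token_slice.start
        ubatch_slices[1] = UBatchSliceStub(slice(start, num_total_tokens))

    slices = [UBatchSliceStub(slice(0, 96)), UBatchSliceStub(slice(96, 160))]
    pad_out_ubatch_slice(slices, 192)  # pad the 160 scheduled tokens out to 192
    assert slices[1].token_slice == slice(96, 192)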
@@ -2085,12 +2086,13 @@ def _preprocess(
         dict[str, Any],
     ]:
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        is_first_rank = get_pp_group().is_first_rank

         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
         if (
             self.supports_mm_inputs
-            and get_pp_group().is_first_rank
+            and is_first_rank
             and not self.model_config.is_encoder_decoder
         ):
             # Run the multimodal encoder if any.
@@ -2115,7 +2117,7 @@ def _preprocess(
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
             }
-        elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
+        elif self.enable_prompt_embeds and is_first_rank:
             # Get the input embeddings for the tokens that are not input embeds,
             # then put them into the appropriate positions.
             # TODO(qthequartermasterman): Since even when prompt embeds are
@@ -2155,7 +2157,7 @@ def _preprocess(
         else:
             positions = self.positions.gpu[:num_input_tokens]

-        if get_pp_group().is_first_rank:
+        if is_first_rank:
             intermediate_tensors = None
         else:
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
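Review note: the three hunks above are one refactor: _preprocess previously called get_pp_group().is_first_rank in three places and now evaluates it once up front. Rank membership cannot change mid-call, so caching it in a local is behavior-preserving and trims the long conditionals. The shape of the change, with a hypothetical group_lookup() standing in for get_pp_group():

    class Group:
        is_first_rank = True

    def group_lookup() -> Group:
        # stands in for get_pp_group(): a lookup that need not be repeated
        return Group()

    def preprocess_before(has_mm: bool, has_embeds: bool) -> str:
        if has_mm and group_lookup().is_first_rank:        # lookup #1
            return "mm"
        elif has_embeds and group_lookup().is_first_rank:  # lookup #2
            return "embeds"
        return "plain"

    def preprocess_after(has_mm: bool, has_embeds: bool) -> str:
        is_first_rank = group_lookup().is_first_rank       # evaluated once
        if has_mm and is_first_rank:
            return "mm"
        elif has_embeds and is_first_rank:
            return "embeds"
        return "plain"

    assert preprocess_before(True, False) == preprocess_after(True, False) == "mm"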
@@ -2186,38 +2188,37 @@ def _sample(
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         if spec_decode_metadata is None:
-            sampler_output = self.sampler(
+            return self.sampler(
                 logits=logits,
                 sampling_metadata=sampling_metadata,
             )
-        else:
-            # When indexing with a tensor (bonus_logits_indices), PyTorch
-            # creates a new tensor with separate storage from the original
-            # logits tensor. This means any in-place operations on bonus_logits
-            # won't affect the original logits tensor.
-            assert logits is not None
-            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
-            sampler_output = self.sampler(
-                logits=bonus_logits,
-                sampling_metadata=sampling_metadata,
-                predict_bonus_token=True,
-            )
-            bonus_token_ids = sampler_output.sampled_token_ids
-
-            # Just like `bonus_logits`, `target_logits` is a new tensor with
-            # separate storage from the original `logits` tensor. Therefore,
-            # it is safe to update `target_logits` in place.
-            target_logits = logits[spec_decode_metadata.target_logits_indices]
-            output_token_ids = self.rejection_sampler(
-                spec_decode_metadata,
-                None,  # draft_probs
-                target_logits,
-                bonus_token_ids,
-                sampling_metadata,
-            )
-            sampler_output.sampled_token_ids = output_token_ids
-            self._update_states_after_model_execute(output_token_ids)
-
+
+        # When indexing with a tensor (bonus_logits_indices), PyTorch
+        # creates a new tensor with separate storage from the original
+        # logits tensor. This means any in-place operations on bonus_logits
+        # won't affect the original logits tensor.
+        assert logits is not None
+        bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
+        sampler_output = self.sampler(
+            logits=bonus_logits,
+            sampling_metadata=sampling_metadata,
+            predict_bonus_token=True,
+        )
+        bonus_token_ids = sampler_output.sampled_token_ids
+
+        # Just like `bonus_logits`, `target_logits` is a new tensor with
+        # separate storage from the original `logits` tensor. Therefore,
+        # it is safe to update `target_logits` in place.
+        target_logits = logits[spec_decode_metadata.target_logits_indices]
+        output_token_ids = self.rejection_sampler(
+            spec_decode_metadata,
+            None,  # draft_probs
+            target_logits,
+            bonus_token_ids,
+            sampling_metadata,
+        )
+        sampler_output.sampled_token_ids = output_token_ids
+        self._update_states_after_model_execute(output_token_ids)
         return sampler_output

     def _bookkeeping_sync(
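Review note: returning early when spec_decode_metadata is None lets the whole speculative-decoding path shed its else: indentation without changing behavior. The comments carried over in the moved block rest on a checkable PyTorch fact: indexing with an index tensor (advanced indexing) produces a copy, so in-place updates to the result never write back into logits. A standalone check with synthetic shapes, not vLLM's:

    import torch

    logits = torch.zeros(4, 8)
    bonus_logits_indices = torch.tensor([0, 2])

    bonus_logits = logits[bonus_logits_indices]  # advanced indexing -> new storage
    bonus_logits += 1.0                          # in-place edit of the copy
    assert logits.sum().item() == 0.0            # original logits are unchanged

    row_view = logits[1]                         # plain int indexing -> a view
    row_view += 1.0                              # this one DOES write through
    assert logits.sum().item() == 8.0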
@@ -3741,7 +3742,7 @@ def freeze_gc():
             decode_cudagraph_batch_sizes = [
                 x
                 for x in self.cudagraph_batch_sizes
-                if x <= max_num_tokens and x >= self.uniform_decode_query_len
+                if max_num_tokens >= x >= self.uniform_decode_query_len
             ]
             compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
             self._capture_cudagraphs(
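Review note: the rewritten filter uses Python's comparison chaining, where max_num_tokens >= x >= self.uniform_decode_query_len means max_num_tokens >= x and x >= self.uniform_decode_query_len, with x evaluated once. A quick equivalence check with made-up sizes:

    max_num_tokens, uniform_decode_query_len = 512, 4
    cudagraph_batch_sizes = [1, 2, 4, 8, 256, 1024]

    old = [x for x in cudagraph_batch_sizes
           if x <= max_num_tokens and x >= uniform_decode_query_len]
    new = [x for x in cudagraph_batch_sizes
           if max_num_tokens >= x >= uniform_decode_query_len]
    assert old == new == [4, 8, 256]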