@@ -396,6 +396,7 @@ def _group_requests_by_sampling_strategy(
         requests: Iterable[LlmRequest],
         *,
         pin_memory: bool = False) -> dict[Strategy, torch.Tensor]:
+    # NB: Client code relies on request indices in returned torch.Tensor being sorted.
     strategy_dict: dict[Strategy, list[int]] = defaultdict(list)
     for req_index, req in enumerate(requests):
         strategy_dict[_request_strategy(req)].append(req_index)
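
The comment added here documents the invariant the second hunk depends on: `enumerate` yields request indices in ascending order and each index is appended to exactly one per-strategy list, so every list (and the tensor built from it) comes out sorted. A minimal self-contained sketch of that invariant, with strings standing in for `LlmRequest` and `_request_strategy`, which are not reproduced here:

    from collections import defaultdict

    # Stand-ins for LlmRequest objects and _request_strategy(req).
    requests = ["greedy", "top_p", "greedy", "top_k", "top_p"]
    strategy_dict: dict[str, list[int]] = defaultdict(list)
    for req_index, strategy in enumerate(requests):
        strategy_dict[strategy].append(req_index)
    # enumerate() produces ascending indices and each index is appended to
    # exactly one list, so every per-strategy index list is sorted.
    assert all(idx == sorted(idx) for idx in strategy_dict.values())
    print(dict(strategy_dict))  # {'greedy': [0, 2], 'top_p': [1, 4], 'top_k': [3]}
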
@@ -1372,12 +1373,20 @@ def _sample_batched_by_strategy(
             len(speculation_group_indices), dtype=torch.int32)
 
         group_logits_cuda_indices = logits_cuda_indexer[group_req_indices]
-        if group_logits_cuda_indices.numel() != logits_cuda.size(0):
+        # NB: Assuming that group_req_indices are sorted
+        group_req_1st_index, group_req_last_index = group_req_indices[
+            0], group_req_indices[-1]
+        if group_req_last_index - group_req_1st_index + 1 == len(
+                group_req_indices):
+            # Avoid data movement if indices are contiguous
+            group_logits_cuda = logits_cuda[
+                req_offsets[group_req_1st_index]:(
+                    req_offsets[group_req_last_index] +
+                    req_num_steps[group_req_last_index])]
+        else:
             group_logits_cuda_indices_cuda = group_logits_cuda_indices.to(
                 device=logits_cuda.device, non_blocking=True)
             group_logits_cuda = logits_cuda[group_logits_cuda_indices_cuda]
-        else:
-            group_logits_cuda = logits_cuda
 
         # Indexer for accessing tokens in 'group_logits_cuda' (and 'group_next_tokens_cuda')
         # corresponding to the requests in 'group_req_indices'.
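
This hunk appears to generalize the previous fast path: the old check only skipped the gather when the group covered every row of `logits_cuda`, whereas for sorted, unique indices `last - first + 1 == len(indices)` holds exactly when the indices form a contiguous run, so any contiguous group can be served by a plain slice (a view, with no index upload or row copy). A self-contained sketch of the check, with made-up tensors standing in for `logits_cuda`, `req_offsets`, and `req_num_steps`:

    import torch

    # Names follow the diff; all values here are invented for illustration.
    req_num_steps = torch.tensor([2, 1, 3, 2])                         # rows per request
    req_offsets = torch.cumsum(req_num_steps, dim=0) - req_num_steps   # first row: [0, 2, 3, 6]
    logits = torch.arange(8.0).unsqueeze(1)                            # 8 rows, one per step

    group_req_indices = torch.tensor([1, 2, 3])                        # sorted, unique
    first, last = group_req_indices[0], group_req_indices[-1]
    if last - first + 1 == len(group_req_indices):
        # Sorted + unique + spanning exactly len(indices) values => contiguous
        # run, so a slice (a view, no data movement) replaces the gather.
        group_logits = logits[req_offsets[first]:req_offsets[last] + req_num_steps[last]]
    else:
        # Non-contiguous: gather the rows for each request (this copies).
        rows = torch.cat([
            torch.arange(int(req_offsets[i]), int(req_offsets[i] + req_num_steps[i]))
            for i in group_req_indices
        ])
        group_logits = logits[rows]
    assert group_logits.shape[0] == int(req_num_steps[group_req_indices].sum())
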