@@ -9,6 +9,7 @@
 from typing import Any, List, Literal, Optional, cast
 
 import torch
+import torch.nn.functional as F
 
 from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \
     MakeDecodingBatchInputOutput
@@ -891,13 +892,16 @@ def handle_logprobs(self, request: LlmRequest, state: SampleState, *,
                         beam: int, count: int):
         current_slice = slice(0, count), request.py_seq_slot, beam
         if request.py_return_log_probs:
-            assert state.host.log_probs is not None
-            log_probs = state.host.log_probs[request.py_seq_slot][beam][:count]
-            current_tokens = state.host.new_tokens[current_slice]
+            topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
+            topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
 
             token_log_probs = [{
-                int(token): Logprob(logprob=logprob, rank=1)
-            } for token, logprob in zip(current_tokens, log_probs.tolist())]
+                token: Logprob(logprob=logprob, rank=rank + 1)
+                for rank, (token, logprob) in enumerate(
+                    zip(topk_token.tolist(), topk_logprob.tolist()))
+            }
+                               for topk_token, topk_logprob in zip(
+                                   topk_log_probs_indices, topk_log_probs_vals)]
             assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element"
             request.py_result.append_log_probs([token_log_probs])
 
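Note on the new per-token dicts: each entry of `token_log_probs` now maps up to k candidate token ids to rank-ordered `Logprob` entries, rather than a single rank-1 token. A minimal self-contained sketch of the same assembly (plain dicts stand in for `Logprob`; tensor values are made up):

```python
import torch

# Assumed shapes: [num_tokens, k], as produced by torch.topk over log-probs.
topk_vals = torch.tensor([[-0.1, -2.3], [-0.5, -1.2]])
topk_indices = torch.tensor([[42, 7], [13, 99]])

token_log_probs = [{
    token: {"logprob": logprob, "rank": rank + 1}
    for rank, (token, logprob) in enumerate(
        zip(row_tokens.tolist(), row_vals.tolist()))
} for row_tokens, row_vals in zip(topk_indices, topk_vals)]
# -> first entry maps 42 -> rank 1 and 7 -> rank 2; second maps 13 -> rank 1, etc.
```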
@@ -1162,13 +1166,8 @@ def log_probs_host(
             self,
-            scheduled_requests: ScheduledRequests) -> Optional[torch.Tensor]:
-        """Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103"""
-        if any(req.py_return_log_probs
-               for req in scheduled_requests.all_requests()):
-            return torch.empty(
-                (self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens),
-                device="cpu",
-                pin_memory=True)
-        return None
+            scheduled_requests: ScheduledRequests) -> bool:
+        """Whether any scheduled request asks for log probs."""
+        return any(req.py_return_log_probs
+                   for req in scheduled_requests.all_requests())
 
     @override
     @torch.inference_mode()
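This hunk replaces the dense, pinned host allocation with a plain boolean flag: per-request top-k tensors are now copied to the CPU individually, so no `[max_num_sequences, beam, max_tokens]` staging buffer is needed. A rough sketch (sizes are hypothetical) of what the old path paid up front even when a single request wanted log probs:

```python
import torch

# Hypothetical sizes for illustration only; pin_memory needs a CUDA build.
max_num_sequences, max_beam_width, max_tokens = 256, 1, 8192
buf = torch.empty((max_num_sequences, max_beam_width, max_tokens),
                  device="cpu", pin_memory=True)
print(f"{buf.numel() * buf.element_size() / 2**20:.0f} MiB pinned")  # 8 MiB
```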
@@ -1198,8 +1197,7 @@ def sample_async(
             sampler_event.record()
         return SampleState(scheduled_requests=scheduled_requests,
                            device=SampleStateTensors(new_tokens=new_tokens),
-                           host=SampleStateTensors(new_tokens=new_tokens_host,
-                                                   log_probs=log_probs_host),
+                           host=SampleStateTensors(new_tokens=new_tokens_host),
                            sampler_event=sampler_event)
 
     @staticmethod
@@ -1308,12 +1306,22 @@ def _sample_batched_by_strategy(
         model_outputs: dict[str, torch.Tensor],
         *,
         cuda_device: torch.device,
-        log_probs_host: torch.Tensor | None = None,
+        log_probs_host: bool = False,
         req_num_steps: torch.Tensor,
         req_offsets: torch.Tensor,
         steps_dim_size: int,
         token_dtype: torch.dtype,
     ) -> _BatchedSamplingResult:
+        if log_probs_host:
+            assert logits_cuda.dim() == 2, "logits should be 2D"
+            logprobs = F.log_softmax(logits_cuda.to("cuda",
+                                                    dtype=torch.float32),
+                                     dim=-1)
+            topk_vals, topk_indices = torch.topk(logprobs,
+                                                 k=max(req.py_num_logprobs
+                                                       for req in requests),
+                                                 dim=-1)
+
         requests_by_strategy = _group_requests_by_sampling_strategy(
             requests, pin_memory=True)
         generator_cuda = self.get_generator(cuda_device)
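The added block computes a single top-k over the whole sampled-token batch, sized by the largest `py_num_logprobs` among the requests; each request later slices off only what it asked for. A standalone sketch of the same pattern (sizes and per-request k values are made up):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(5, 32000)      # [num_sampled_tokens, vocab_size]
num_logprobs_per_req = [2, 5, 1]    # assumed per-request top-k sizes
# float32 log-softmax for numerical stability, then one shared top-k.
logprobs = F.log_softmax(logits.float(), dim=-1)
topk_vals, topk_indices = torch.topk(logprobs,
                                     k=max(num_logprobs_per_req),
                                     dim=-1)
print(topk_vals.shape)  # torch.Size([5, 5])
```

One wide top-k amortizes the kernel launch across requests, at the cost of computing a few extra columns for requests that asked for a smaller k.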
@@ -1357,12 +1365,20 @@ def _sample_batched_by_strategy(
             # softmax_grp_indices: Indices of 'speculation_group_indices' entries requesting probs
             # speculation_softmax_indices: Indices of 'softmax_grp_indices' entries corresponding
             #     to requests with draft logits.
-            if log_probs_host is not None:
+            if log_probs_host:
                 softmax_req_indices = group_req_indices
                 softmax_grp_indices = torch.arange(len(group_req_indices),
                                                    dtype=torch.int32)
                 speculation_softmax_indices = torch.tensor(
                     speculation_group_indices, dtype=torch.int32)
+                for req_id in group_req_indices:
+                    req = requests[req_id]
+                    req.py_topk_logprobs_vals = topk_vals[
+                        logits_cuda_indexer[req_id], :req.py_num_logprobs].to(
+                            device="cpu", non_blocking=True)
+                    req.py_topk_logprobs_indices = topk_indices[
+                        logits_cuda_indexer[req_id], :req.py_num_logprobs].to(
+                            device="cpu", non_blocking=True)
             else:
                 speculation_group_indices_tensor = torch.tensor(
                     speculation_group_indices, dtype=torch.int32)
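The `.to(device="cpu", non_blocking=True)` copies above are asynchronous with respect to the host: the CPU tensors must not be read until the producing stream has been synchronized (in this sampler, the `sampler_event` recorded in `sample_async` provides that ordering before `handle_logprobs` runs). A small demonstration of the required discipline:

```python
import torch

if torch.cuda.is_available():
    vals_gpu = torch.randn(4, 8, device="cuda")
    # The copy is enqueued on the current stream and may still be in
    # flight when .to(...) returns.
    vals_cpu = vals_gpu.to(device="cpu", non_blocking=True)
    torch.cuda.current_stream().synchronize()  # make the copy visible
    assert torch.allclose(vals_cpu, vals_gpu.cpu())
```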
@@ -1462,7 +1478,7 @@ def _unbatch_sampling_results(
         new_tokens_cuda: torch.Tensor,
         req_num_steps: torch.Tensor,
         seq_slots: torch.Tensor,
-        log_probs_host: torch.Tensor | None = None,
+        log_probs_host: bool = False,
     ) -> torch.Tensor:
         beam = self.BEAM
         assert beam == 0, "beam_width != 1 not supported"
@@ -1479,17 +1495,6 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
         # Assert destination tensor dimensions are canonically ordered ("row"-major); this
         # matters for element ordering in the .view(...).scatter_(...) calls below.
         assert _dims_canonically_ordered(new_tokens_cuda)
-        assert log_probs_host is None or _dims_canonically_ordered(
-            log_probs_host)
-
-        # new_tokens_cuda indexed by
-        #     slice(0, steps), slot, beam
-        # log_probs_host indexed by
-        #     slot, beam, slice(0, steps)
-        # batch_... tensors indexed by slice(batch_req_index, batch_req_index + steps)
-        #
-        if log_probs_host is not None:
-            assert new_tokens_cuda.size(0) == log_probs_host.size(-2)
 
         # Construct index mapping from slice indices of computed tensors
         # (packed request_idx and step dimensions) to linearized indices
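`_dims_canonically_ordered` (named in the hunk header and the surviving assert) is not shown in this diff; a plausible stride-based reading of what it checks, assuming "canonically ordered" means the memory layout follows the dimension order (row-major):

```python
import torch

def dims_canonically_ordered(t: torch.Tensor) -> bool:
    # Row-major layout implies non-increasing strides across dimensions,
    # which is what .view(...).scatter_(...) relies on below.
    strides = t.stride()
    return all(a >= b for a, b in zip(strides, strides[1:]))

assert dims_canonically_ordered(torch.empty(3, 4, 5))
assert not dims_canonically_ordered(torch.empty(3, 4, 5).transpose(0, 1))
```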
@@ -1511,39 +1516,6 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
                 0, batch_dest_indices_1d_cuda,
                 batch_next_tokens_cuda_int)
         new_tokens_host = new_tokens_cuda.to("cpu", non_blocking=True)
-        # NB: In order to avoid a scatter_ on the host and the necessary D2H copy + synchronization,
-        # the 'step' and 'seq_slot' dimensions are unpacked on GPU and later asynchronously
-        # copied into the destination buffer. Note that this overwrites all 'step' and token slots for the
-        # requests in 'requests' (passed to _process_requests). In fact, the current implementation
-        # even overwrites the destination tensors completely (including slices corresponding to request
-        # slots not present in 'requests', cf. 'FIXME' below).
-        if log_probs_host is not None:
-            # FIXME: If log_probs_host were indexed by request indices, rather than request slots, this
-            # tensor could be packed densely along the request axis.
-            log_probs_cuda = torch.empty_like(
-                log_probs_host, device=batch_dest_indices_1d_cuda.device)
-            # FIXME: Needs a separate indexer because tensor layout differs from new_tokens_cuda
-            batch_dest_probs_cuda_indexer = _UnpackedStepIndexer(
-                seq_slots=seq_slots[batch_req_indices],
-                num_steps=req_num_steps[batch_req_indices],
-                steps_dim_size=new_tokens_cuda.size(0),
-                slots_dim_size=new_tokens_cuda.size(1),
-                dim_order=_UnpackedStepIndexer.DimOrder.SLOT_MAJOR,
-                index_dtype=torch.int64,  # enforced by Tensor.scatter_
-            )
-            batch_dest_probs_indices_cuda = batch_dest_probs_cuda_indexer[:].to(
-                batch_softmax_cuda.device, non_blocking=True)
-            # NB: torch.arange is needed to enable "advanced indexing",
-            # cf. https://numpy.org/devdocs/user/basics.indexing.html#integer-array-indexing
-            batch_token_probs = batch_softmax_cuda[
-                torch.arange(batch_softmax_cuda.size(0),
-                             device=batch_softmax_cuda.device,
-                             dtype=torch.int32), batch_next_tokens_cuda_int]
-            log_probs_cuda[:, beam,
-                           ...].view(-1, *log_probs_cuda.shape[3:]).scatter_(
-                               0, batch_dest_probs_indices_cuda,
-                               torch.log(batch_token_probs))
-            log_probs_host.copy_(log_probs_cuda, non_blocking=True)
         # For requests with LlmRequest.py_draft_logits, return py_target_probs
         for request, batch_softmax_index_cuda in py_draft_logits_indices:
             request.py_target_probs = batch_softmax_cuda[