@@ -202,12 +202,13 @@ class SampleStateWithMMResult:
 class RequestGroupKey(Generic[GenericStrategyKeyType]):
     strategy: GenericStrategyKeyType
     speculation_needs_probs: bool
+    need_processed_logprobs: bool
 
     def __iter__(self):
-        return iter((self.strategy, self.speculation_needs_probs))
+        return iter((self.strategy, self.speculation_needs_probs, self.need_processed_logprobs))
 
     def __len__(self):
-        return 2
+        return 3
 
 
 class RequestGroupValue(NamedTuple):
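A minimal sketch of why `__iter__` and `__len__` must track the new field: the group key is unpacked like a tuple in the sampling loop, so all three pieces have to come back out in order. The `_GroupKeySketch` class below is a simplified, hypothetical stand-in for `RequestGroupKey`, not the real class.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class _GroupKeySketch:
    strategy: str
    speculation_needs_probs: bool
    need_processed_logprobs: bool

    def __iter__(self):
        # Tuple-style unpacking relies on this ordering.
        return iter((self.strategy, self.speculation_needs_probs, self.need_processed_logprobs))

    def __len__(self):
        return 3

key = _GroupKeySketch("top_p", False, True)
strategy, needs_probs, needs_processed = key  # unpacks via __iter__
assert len(key) == 3
```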
@@ -338,13 +339,19 @@ def _group_requests_by_strategy_key(
             # process_draft_tokens.
             TorchSampler._speculation_could_use_rejection_sampling(req, strategy)
         )
-        strategy_key = strategy_to_key(strategy, speculation_needs_probs)
-        group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)]
+        need_processed_logprobs = req.py_logprobs_mode == "processed"
+        need_probs = speculation_needs_probs or need_processed_logprobs
+        strategy_key = strategy_to_key(strategy, need_probs)
+        group_dict_entry = group_dict[
+            (strategy_key, speculation_needs_probs, need_processed_logprobs)
+        ]
         group_dict_entry[0].append(req_index)
         group_dict_entry[1].append(strategy)
     return {
         RequestGroupKey(
-            strategy=group_key[0], speculation_needs_probs=group_key[1]
+            strategy=group_key[0],
+            speculation_needs_probs=group_key[1],
+            need_processed_logprobs=group_key[2],
         ): RequestGroupValue(
             indices=torch.tensor(indices, pin_memory=pin_memory, dtype=torch.int32),
             strategies=strategies,
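A rough, self-contained sketch of the grouping above, under the assumption that requests expose `py_draft_logits` and `py_logprobs_mode` as in the diff; the strategy and `strategy_to_key` are stubbed here, so this only illustrates how the combined `need_probs` flag and the three-part dictionary key interact.

```python
from collections import defaultdict

def group_requests_sketch(requests):
    """Illustrative only: bucket request indices by (strategy key, needs_probs, needs_processed)."""
    group_dict: dict[tuple, list[int]] = defaultdict(list)
    for req_index, req in enumerate(requests):
        speculation_needs_probs = getattr(req, "py_draft_logits", None) is not None
        need_processed_logprobs = getattr(req, "py_logprobs_mode", None) == "processed"
        # The sampler must return softmax probabilities if either consumer needs them,
        # so the strategy key is derived from the combined flag ...
        need_probs = speculation_needs_probs or need_processed_logprobs
        strategy_key = ("greedy", need_probs)  # stand-in for strategy_to_key(strategy, need_probs)
        # ... while the group key keeps both flags separate, because they trigger different
        # post-processing (rejection-sampling probabilities vs. processed logprobs).
        group_dict[(strategy_key, speculation_needs_probs, need_processed_logprobs)].append(req_index)
    return dict(group_dict)
```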
@@ -374,6 +381,8 @@ class _BatchedSamplingResult:
     batch_req_indices: torch.Tensor
     # Next tokens for all requests:
     batch_next_tokens_cuda_int: torch.Tensor
+    # Logits for all requests (populated only when logprobs are requested):
+    batch_logits_cuda: torch.Tensor | None = None
 
 
 # Helper class for _PackedStepIndexer and _UnpackedStepIndexer, facilitating the
@@ -942,34 +951,55 @@ def _convert_logprobs_tensor_to_list(
         self,
         token_tensor: torch.Tensor,
         logprobs_tensor: torch.Tensor,
+        sampled_log_probs_indices: torch.Tensor | None,
+        sampled_log_probs_vals: torch.Tensor | None,
+        sampled_log_probs_rank: torch.Tensor | None,
     ) -> list[list[dict[int, Logprob]]]:
         """Convert the logprobs tensor to a list of lists of dictionaries of Logprob objects
 
         Logprobs storage expects logprobs as a list[list[dict[int, Logprob]]] object
 
         args:
+            token_tensor: torch.Tensor. Shape: beam_width, num_tokens, num_logprobs
             logprobs_tensor: torch.Tensor. Shape: beam_width, num_tokens, num_logprobs
+            sampled_log_probs_indices: torch.Tensor | None. Shape: num_tokens
+            sampled_log_probs_vals: torch.Tensor | None. Shape: num_tokens
+            sampled_log_probs_rank: torch.Tensor | None. Shape: num_tokens
         output:
             list[list[dict[int, Logprob]]]. Shape: beam_width, num_tokens, dict with num_logprobs keys
         """
         assert token_tensor.dim() == 3 and logprobs_tensor.dim() == 3, (
             f"Token and logprobs tensors must have 3 dimensions (beam_width, num_tokens, num_logprobs). \
             Got shapes (token_tensor) {token_tensor.shape} and (logprobs_tensor) {logprobs_tensor.shape} instead"
         )
-        return [
-            [
-                {
+
+        token_log_probs: list[list[dict[int, Logprob]]] = []
+        for beam_idx in range(token_tensor.shape[0]):
+            beam_token_log_probs: list[dict[int, Logprob]] = []
+            for step_idx, (topk_token, topk_logprob) in enumerate(
+                zip(token_tensor[beam_idx], logprobs_tensor[beam_idx])
+            ):
+                logprobs = {
                     token: Logprob(logprob=logprob, rank=rank + 1)
                     for rank, (token, logprob) in enumerate(
                         zip(topk_token.tolist(), topk_logprob.tolist())
                     )
                 }
-                for topk_token, topk_logprob in zip(
-                    token_tensor[beam_idx], logprobs_tensor[beam_idx]
-                )
-            ]
-            for beam_idx in range(token_tensor.shape[0])
-        ]
+                if sampled_log_probs_indices is not None:
+                    assert beam_idx == 0, (
+                        "beam search does not need to explicitly handle sampled log probs"
+                    )
+                    if sampled_log_probs_indices[step_idx] not in logprobs:
+                        logprobs[sampled_log_probs_indices[step_idx].item()] = Logprob(
+                            logprob=sampled_log_probs_vals[step_idx].item(),
+                            rank=max(
+                                token_tensor.shape[2] + 1, sampled_log_probs_rank[step_idx].item()
+                            ),
+                        )
+                beam_token_log_probs.append(logprobs)
+            token_log_probs.append(beam_token_log_probs)
+
+        return token_log_probs
 
     def handle_logprobs(
         self,
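The per-step merge performed by the rewritten loop can be shown in isolation. The sketch below stubs `Logprob` with a namedtuple (the real class comes from the result-handling code) and shows how a sampled token that fell outside the requested top-k is appended with a rank of at least `k + 1`.

```python
from collections import namedtuple

Logprob = namedtuple("Logprob", ["logprob", "rank"])  # stub for illustration only

def merge_sampled_token(topk_tokens, topk_logprobs, sampled_token, sampled_logprob, sampled_rank_1idx):
    """One decoding step: build the top-k dict, then add the sampled token if it is not in the top-k."""
    step = {
        tok: Logprob(logprob=lp, rank=rank + 1)
        for rank, (tok, lp) in enumerate(zip(topk_tokens, topk_logprobs))
    }
    if sampled_token not in step:
        # The sampled token is outside the top-k, so its rank is at least k + 1.
        step[sampled_token] = Logprob(
            logprob=sampled_logprob, rank=max(len(topk_tokens) + 1, sampled_rank_1idx)
        )
    return step

# Example: top-2 logprobs requested, but a lower-ranked token (id 42) was actually sampled.
print(merge_sampled_token([7, 3], [-0.2, -1.9], 42, -3.4, 5))
```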
@@ -986,6 +1016,10 @@ def handle_logprobs(
             topk_log_probs_indices = self.store.new_tokens[0, request.py_seq_slot].view(
                 beam_width, count, -1
             )
+            sampled_log_probs_vals = None
+            sampled_log_probs_indices = None
+            # the beam-search path does not provide separate sampled-token logprobs
+            sampled_log_probs_rank = None
         else:
             assert beam_width == 1, "beam width must be 1 for non-beam search"
             topk_log_probs_vals = request.py_topk_logprobs_vals[: count * beam_width].view(
@@ -994,9 +1028,17 @@ def handle_logprobs(
             topk_log_probs_indices = request.py_topk_logprobs_indices[
                 : count * beam_width
             ].view(beam_width, count, -1)
+            sampled_log_probs_vals = request.py_sampled_logprobs_vals[:count]
+            sampled_log_probs_indices = request.py_sampled_logprobs_indices[:count]
+            # correct the rank to be 1-indexed
+            sampled_log_probs_rank = request.py_sampled_logprobs_rank[:count] + 1
 
         token_log_probs = self._convert_logprobs_tensor_to_list(
-            topk_log_probs_indices, topk_log_probs_vals
+            topk_log_probs_indices,
+            topk_log_probs_vals,
+            sampled_log_probs_indices,
+            sampled_log_probs_vals,
+            sampled_log_probs_rank,
         )
         request.py_result.append_log_probs(token_log_probs)
 
@@ -1865,6 +1907,7 @@ def _sample_batched_by_strategy(
         seq_slots: torch.Tensor,
         seq_lens: Optional[torch.Tensor] = None,
         token_dtype: torch.dtype,
+        return_log_probs: bool,
     ) -> _BatchedSamplingResult:
         grouped_requests = _group_requests_by_strategy_key(
             requests,
@@ -1894,9 +1937,16 @@ def _sample_batched_by_strategy(
         batch_next_tokens_cuda_int = torch.empty(
             (logits_cuda.size(0), self.max_beam_width), device=cuda_device, dtype=token_dtype
         )
+        batch_logits_cuda = (
+            torch.empty(
+                (logits_cuda.size(0), logits_cuda.size(1)), device=cuda_device, dtype=torch.float32
+            )
+            if return_log_probs
+            else None
+        )
         batch_req_idx_offset_start = 0
         batch_next_tokens_offset_start = 0
-        for (strategy_key, speculation_needs_probs), (
+        for (strategy_key, speculation_needs_probs, need_processed_logprobs), (
             group_req_indices,
             group_strategies,
             group_metadata,
@@ -1943,7 +1993,7 @@ def _sample_batched_by_strategy(
                 group_strategies_per_step,
                 group_logits_cuda,
                 generator=generator_cuda,
-                return_probs=speculation_needs_probs,
+                return_probs=speculation_needs_probs or need_processed_logprobs,
                 group_logit_indices=logit_indices_for_sampler,
                 group_metadata=group_metadata,
             )
@@ -1958,6 +2008,20 @@ def _sample_batched_by_strategy(
                 batch_next_tokens_offset_start:batch_next_tokens_offset_end
             ].copy_(group_next_tokens_cuda, non_blocking=True)
 
+            if return_log_probs:
+                if need_processed_logprobs:
+                    # if softmax is 0, then the logit was masked out => set to -inf
+                    group_tgt_logits_cuda = torch.where(
+                        group_softmax_cuda != 0, group_logits_cuda, float("-inf")
+                    )
+                    batch_logits_cuda[
+                        batch_next_tokens_offset_start:batch_next_tokens_offset_end
+                    ].copy_(group_tgt_logits_cuda, non_blocking=True)
+                else:
+                    batch_logits_cuda[
+                        batch_next_tokens_offset_start:batch_next_tokens_offset_end
+                    ].copy_(group_logits_cuda, non_blocking=True)
+
             # Set LlmRequest.py_target_probs
             if speculation_needs_probs:
                 assert group_softmax_cuda is not None
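The "processed" branch above recovers post-masking logits from the sampler's softmax output: any vocabulary entry whose probability is exactly zero was masked by top-k/top-p (or similar), so its logit is forced to -inf and the later `log_softmax` renormalizes over the surviving tokens only. A standalone sketch of that idea with made-up tensors:

```python
import torch

raw_logits = torch.tensor([[2.0, 1.0, -1.0, -3.0]])
probs = torch.tensor([[0.7311, 0.2689, 0.0, 0.0]])  # e.g. softmax after top-k=2 masking

# Zero probability => the token was masked out; pin its logit to -inf so the
# subsequent log_softmax yields -inf there and renormalizes over the kept tokens.
processed_logits = torch.where(probs != 0, raw_logits, float("-inf"))
processed_logprobs = torch.log_softmax(processed_logits, dim=-1)
print(processed_logprobs)  # finite for the two kept tokens, -inf for the masked ones
```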
@@ -1986,6 +2050,7 @@ def _sample_batched_by_strategy(
         return _BatchedSamplingResult(
             batch_req_indices=batch_req_indices,
             batch_next_tokens_cuda_int=batch_next_tokens_cuda_int,
+            batch_logits_cuda=batch_logits_cuda,
         )
 
     def _unbatch_sampling_results(
@@ -2385,6 +2450,63 @@ def request_stop_words(request: LlmRequest, new_tokens: torch.Tensor):
                     per_step[step, request_idx, beam_idx] = True
         return per_step
 
+    @nvtx_range("_process_logprobs")
+    def _process_logprobs(
+        self,
+        batched_sampling_result: _BatchedSamplingResult,
+        requests: list[LlmRequest],
+        req_num_steps: torch.Tensor,
+    ):
+        group_logprobs_cuda = F.log_softmax(batched_sampling_result.batch_logits_cuda, dim=-1)
+        all_req_indices = batched_sampling_result.batch_req_indices
+        group_next_tokens_cuda = batched_sampling_result.batch_next_tokens_cuda_int
+        group_req_indices = [
+            req_gid.item()
+            for req_gid in all_req_indices
+            if requests[req_gid].py_num_logprobs is not None
+        ]
+        topk_vals_cuda, topk_indices_cuda = torch.topk(
+            group_logprobs_cuda,
+            k=max(requests[req_id].py_num_logprobs for req_id in group_req_indices),
+            dim=-1,
+        )
+
+        sampled_vals_cuda = torch.gather(
+            group_logprobs_cuda, dim=-1, index=group_next_tokens_cuda.view(-1, 1)
+        )
+        sampled_indices_cuda = group_next_tokens_cuda
+
+        # NB: group_logprobs_cuda is no longer needed, so its storage is reused in place.
+        # Only a 0-based rank is computed here; it is corrected to 1-indexed in handle_logprobs.
+        group_logprobs_cuda.greater_(sampled_vals_cuda)
+        sampled_rank_cuda = group_logprobs_cuda.sum(dim=-1)
+
+        # Stage the device-to-host copies back to back to reduce overheads
+        topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=False)
+        topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=False)
+        sampled_vals = torch.empty_like(sampled_vals_cuda, device="cpu", pin_memory=False)
+        sampled_indices = torch.empty_like(sampled_indices_cuda, device="cpu", pin_memory=False)
+        sampled_rank = torch.empty_like(sampled_rank_cuda, device="cpu", pin_memory=False)
+
+        topk_vals.copy_(topk_vals_cuda, non_blocking=True)
+        topk_indices.copy_(topk_indices_cuda, non_blocking=True)
+        sampled_vals.copy_(sampled_vals_cuda, non_blocking=True)
+        sampled_indices.copy_(sampled_indices_cuda, non_blocking=True)
+        sampled_rank.copy_(sampled_rank_cuda, non_blocking=True)
+        current_offset = 0
+        for req_id, steps in zip(group_req_indices, req_num_steps[group_req_indices].tolist()):
+            req = requests[req_id]
+            next_offset = current_offset + steps
+            # NB: Assigning views on memory which is being filled asynchronously
+            req.py_topk_logprobs_vals = topk_vals[current_offset:next_offset, : req.py_num_logprobs]
+            req.py_sampled_logprobs_vals = sampled_vals[current_offset:next_offset]
+            req.py_topk_logprobs_indices = topk_indices[
+                current_offset:next_offset, : req.py_num_logprobs
+            ]
+            req.py_sampled_logprobs_indices = sampled_indices[current_offset:next_offset]
+            req.py_sampled_logprobs_rank = sampled_rank[current_offset:next_offset]
+            current_offset = next_offset
+
     @nvtx_range("_process_requests")
     def _process_requests(
         self,
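The rank computation in `_process_logprobs` avoids a full sort: it gathers the sampled token's logprob and counts how many vocabulary entries beat it, which yields a 0-based rank (shifted to 1-based later in `handle_logprobs`). A compact sketch with plain tensors; unlike the diff, this version does not reuse the logprobs buffer in place:

```python
import torch

logprobs = torch.log_softmax(torch.tensor([[2.0, 0.5, 1.0, -1.0]]), dim=-1)  # (steps=1, vocab=4)
sampled = torch.tensor([[2]])                                                 # sampled token id per step

sampled_vals = torch.gather(logprobs, dim=-1, index=sampled)   # logprob of the sampled token
rank0 = (logprobs > sampled_vals).sum(dim=-1)                  # tokens strictly better => 0-based rank
print(sampled_vals, rank0 + 1)  # 1-indexed rank; here the sampled token is second best => rank 2
```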
@@ -2454,55 +2576,6 @@ def _process_requests(
             req_offsets=req_offsets,
         )
 
-        # Handle top-k logprobs. This is done outside the sampling loop,
-        # because the returned logprobs are specified to not reflect temperature scaling,
-        # top-k/top-p masking, etc.
-        if return_log_probs:
-            assert logits_cuda.dim() == 2, "logits should be 2D"
-
-            logprobs_req_indices = [
-                req_id for req_id, req in enumerate(requests) if req.py_num_logprobs
-            ]
-            logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices]
-            logprobs_logit_indices_cuda = logprobs_logit_indices.to(
-                device=logits_cuda.device, non_blocking=True
-            )
-            logprobs_cuda = F.log_softmax(
-                logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True),
-                dim=-1,
-            )
-            topk_vals_cuda, topk_indices_cuda = torch.topk(
-                logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
-            )
-            # Use a single D2H copy to reduce overheads
-            topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
-            topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
-            topk_vals.copy_(topk_vals_cuda, non_blocking=True)
-            topk_indices.copy_(topk_indices_cuda, non_blocking=True)
-            current_offset = 0
-            for req_id, steps in zip(
-                logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist()
-            ):
-                req = requests[req_id]
-                next_offset = current_offset + steps
-                # NB: Assigning views on memory which is being filled asynchronously
-                req.py_topk_logprobs_vals = topk_vals[
-                    current_offset:next_offset, : req.py_num_logprobs
-                ]
-                req.py_topk_logprobs_indices = topk_indices[
-                    current_offset:next_offset, : req.py_num_logprobs
-                ]
-
-                # context requests do not have multiple input beams, but they need multiple output beams
-                if req.is_context_init_state:
-                    req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand(
-                        req.sampling_config.beam_width, -1
-                    )
-                    req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand(
-                        req.sampling_config.beam_width, -1
-                    )
-                current_offset = next_offset
-
         # Perform sampling in batches
         batched_sampling_result = self._sample_batched_by_strategy(
             logits_cuda,
@@ -2515,8 +2588,12 @@ def _process_requests(
             seq_lens=seq_lens,
             req_num_generated_tokens=req_num_generated_tokens,
             token_dtype=new_tokens_cuda.dtype,
+            return_log_probs=return_log_probs,
         )
 
+        if return_log_probs:
+            self._process_logprobs(batched_sampling_result, requests, req_num_steps)
+
         # Fill results into output buffers
         new_tokens_host = self._unbatch_sampling_results(
             batched_sampling_result,
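Taken together, the change moves logprob extraction from a pre-sampling pass over the raw logits to a post-sampling pass over the logits captured per group. The miniature below walks that flow end to end on a single toy step (mask, sample, then derive "processed" logprobs and rank from the same masked logits); it is a standalone illustration, not the batched group buffers from the diff.

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([[3.0, 2.0, 0.5, -1.0]])

# Sampling step for a top-k=2 group: mask, softmax, sample.
topk = 2
masked = logits.clone()
masked[masked < masked.topk(topk, dim=-1).values[..., -1:]] = float("-inf")
probs = torch.softmax(masked, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)

# Logprob post-processing on the same (masked) logits, as the "processed" mode would see them.
logprobs = torch.log_softmax(masked, dim=-1)
sampled_logprob = torch.gather(logprobs, -1, next_token)
rank_1idx = (logprobs > sampled_logprob).sum(dim=-1) + 1
print(next_token.item(), sampled_logprob.item(), rank_1idx.item())
```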