@@ -986,18 +986,66 @@ def handle_logprobs(
             topk_log_probs_indices = self.store.new_tokens[0, request.py_seq_slot].view(
                 beam_width, count, -1
             )
+            token_log_probs = self._convert_logprobs_tensor_to_list(
+                topk_log_probs_indices, topk_log_probs_vals
+            )
         else:
             assert beam_width == 1, "beam width must be 1 for non-beam search"
-            topk_log_probs_vals = request.py_topk_logprobs_vals[: count * beam_width].view(
-                beam_width, count, -1
-            )
-            topk_log_probs_indices = request.py_topk_logprobs_indices[
-                : count * beam_width
-            ].view(beam_width, count, -1)
+
+            sampled_tokens = request.get_tokens(0)[-count:]
+
+            if request.py_num_logprobs == 0:
+                # Return only the sampled token's logprob
+                # Compute at least top-1 to determine rank
+                if hasattr(request, 'py_sampled_logprobs') and request.py_sampled_logprobs is not None:
+                    sampled_logprobs = request.py_sampled_logprobs[:count]
+                    topk_log_probs_vals = request.py_topk_logprobs_vals[:count]  # At least k=1
+                    topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
+
+                    token_log_probs = []
+                    for step, (sampled_token, sampled_logprob, topk_tokens, topk_logprobs) in enumerate(
+                        zip(sampled_tokens, sampled_logprobs, topk_log_probs_indices, topk_log_probs_vals)
+                    ):
+                        topk_tokens_list = topk_tokens.tolist()
+                        if sampled_token in topk_tokens_list:
+                            # Sampled token is in top-K, use its rank
+                            rank = topk_tokens_list.index(sampled_token) + 1
+                        else:
+                            # TODO: fix rank
+                            rank = 2

-            token_log_probs = self._convert_logprobs_tensor_to_list(
-                topk_log_probs_indices, topk_log_probs_vals
-            )
+                        step_dict = {sampled_token: Logprob(logprob=sampled_logprob.item(), rank=rank)}
+                        token_log_probs.append(step_dict)
+                else:
+                    raise ValueError("py_sampled_logprobs not available when py_num_logprobs == 0")
+            else:
+                # Return top-K logprobs + logprob of sampled token
+                sampled_logprobs = request.py_sampled_logprobs[:count]
+                topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
+                topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
+
+                token_log_probs = []
+                for step, (sampled_token, sampled_logprob, topk_tokens, topk_logprobs) in enumerate(
+                    zip(sampled_tokens, sampled_logprobs, topk_log_probs_indices, topk_log_probs_vals)
+                ):
+                    step_dict = {}
+                    topk_tokens_list = topk_tokens.tolist()
+                    topk_logprobs_list = topk_logprobs.tolist()
+
+                    for rank_idx, (token, logprob) in enumerate(zip(topk_tokens_list, topk_logprobs_list), start=1):
+                        step_dict[token] = Logprob(logprob=logprob, rank=rank_idx)
+
+                    if sampled_token not in step_dict:
+                        # TODO: fix rank
+                        step_dict[sampled_token] = Logprob(
+                            logprob=sampled_logprob.item(),
+                            rank=len(topk_tokens_list) + 1
+                        )
+                    token_log_probs.append(step_dict)
+
+            # Wrap in list for non-beam search (beam_width=1)
+            token_log_probs = [token_log_probs]
+
         request.py_result.append_log_probs(token_log_probs)

     def finish_if_reason(
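For illustration, here is a minimal sketch of the value the non-beam branch above ends up passing to `append_log_probs`: one list per beam (a single inner list here, since `beam_width == 1`), containing one dict per generated step keyed by token id. The `Logprob` dataclass below is only a stand-in for the `Logprob` type used in the diff, and the token ids and values are invented.

```python
from dataclasses import dataclass


@dataclass
class Logprob:
    """Stand-in for the Logprob type used above: a logprob value plus a 1-based rank."""
    logprob: float
    rank: int


# Hypothetical payload for a request with beam_width == 1, py_num_logprobs == 2,
# and two generated steps.
token_log_probs = [  # outer list: one entry per beam
    [
        # step 0: sampled token 42 was also the top-1 candidate, so it is already in the top-k dict
        {42: Logprob(logprob=-0.31, rank=1), 7: Logprob(logprob=-1.95, rank=2)},
        # step 1: sampled token 7 fell outside the top-2, so it is appended with
        # rank len(topk) + 1 (the rank the code above marks with "TODO: fix rank")
        {13: Logprob(logprob=-0.52, rank=1), 99: Logprob(logprob=-1.10, rank=2), 7: Logprob(logprob=-2.40, rank=3)},
    ]
]
```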
@@ -2461,47 +2509,55 @@ def _process_requests(
             assert logits_cuda.dim() == 2, "logits should be 2D"

             logprobs_req_indices = [
-                req_id for req_id, req in enumerate(requests) if req.py_num_logprobs
+                req_id for req_id, req in enumerate(requests) if req.py_num_logprobs is not None
             ]
-            logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices]
-            logprobs_logit_indices_cuda = logprobs_logit_indices.to(
-                device=logits_cuda.device, non_blocking=True
-            )
-            logprobs_cuda = F.log_softmax(
-                logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True),
-                dim=-1,
-            )
-            topk_vals_cuda, topk_indices_cuda = torch.topk(
-                logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
-            )
-            # Use a single D2H copy to reduce overheads
-            topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
-            topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
-            topk_vals.copy_(topk_vals_cuda, non_blocking=True)
-            topk_indices.copy_(topk_indices_cuda, non_blocking=True)
-            current_offset = 0
-            for req_id, steps in zip(
-                logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist()
-            ):
-                req = requests[req_id]
-                next_offset = current_offset + steps
-                # NB: Assigning views on memory which is being filled asynchronously
-                req.py_topk_logprobs_vals = topk_vals[
-                    current_offset:next_offset, : req.py_num_logprobs
-                ]
-                req.py_topk_logprobs_indices = topk_indices[
-                    current_offset:next_offset, : req.py_num_logprobs
-                ]

-                # context requests do not have multiple input beams, but they need multiple output beams
-                if req.is_context_init_state:
-                    req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand(
-                        req.sampling_config.beam_width, -1
-                    )
-                    req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand(
-                        req.sampling_config.beam_width, -1
-                    )
-                current_offset = next_offset
+            if logprobs_req_indices:
+                logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices]
+                logprobs_logit_indices_cuda = logprobs_logit_indices.to(
+                    device=logits_cuda.device, non_blocking=True
+                )
+                logprobs_cuda = F.log_softmax(
+                    logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True),
+                    dim=-1,
+                )
+
+                max_k = max(max(1, req.py_num_logprobs) for req in requests if req.py_num_logprobs is not None)
+                topk_vals_cuda, topk_indices_cuda = torch.topk(
+                    logprobs_cuda,
+                    k=max_k,
+                    dim=-1
+                )
+                # Use a single D2H copy to reduce overheads
+                topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
+                topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
+                topk_vals.copy_(topk_vals_cuda, non_blocking=True)
+                topk_indices.copy_(topk_indices_cuda, non_blocking=True)
+                current_offset = 0
+                for req_id, steps in zip(
+                    logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist()
+                ):
+                    req = requests[req_id]
+                    next_offset = current_offset + steps
+                    # Store at least k=1 for all requests (including logprobs=0) to compute ranks
+                    k_for_req = max(1, req.py_num_logprobs)
+                    # NB: Assigning views on memory which is being filled asynchronously
+                    req.py_topk_logprobs_vals = topk_vals[
+                        current_offset:next_offset, :k_for_req
+                    ]
+                    req.py_topk_logprobs_indices = topk_indices[
+                        current_offset:next_offset, :k_for_req
+                    ]
+
+                    # context requests do not have multiple input beams, but they need multiple output beams
+                    if req.is_context_init_state:
+                        req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand(
+                            req.sampling_config.beam_width, -1
+                        )
+                        req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand(
+                            req.sampling_config.beam_width, -1
+                        )
+                    current_offset = next_offset

         # Perform sampling in batches
         batched_sampling_result = self._sample_batched_by_strategy(
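As context for the block above, the GPU-side pattern is: compute the log-softmax once over the gathered logits, take a single top-k at the largest k requested by any request, then issue one asynchronous device-to-host copy into pinned buffers and hand out per-request slices of those buffers. A self-contained sketch of that pattern, with made-up shapes and k values, assuming a CUDA device is available:

```python
import torch
import torch.nn.functional as F

# Made-up batch: 5 sampled steps over a 32k vocabulary.
logits = torch.randn(5, 32000, device="cuda")
# Per-request logprob counts; a request with 0 still needs top-1 to compute ranks.
requested_k = [0, 4, 2]

logprobs = F.log_softmax(logits.float(), dim=-1)
max_k = max(max(1, k) for k in requested_k)
topk_vals_cuda, topk_indices_cuda = torch.topk(logprobs, k=max_k, dim=-1)

# One pinned-memory D2H copy for the whole batch instead of one copy per request.
topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
topk_vals.copy_(topk_vals_cuda, non_blocking=True)
topk_indices.copy_(topk_indices_cuda, non_blocking=True)

# Per-request slices such as topk_vals[start:end, :max(1, k)] are views into these
# buffers; they are only safe to read once the copy stream has been synchronized.
```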
@@ -2517,6 +2573,52 @@ def _process_requests(
             token_dtype=new_tokens_cuda.dtype,
         )

+        if return_log_probs and logprobs_req_indices:
+            sampled_tokens_cuda = batched_sampling_result.batch_next_tokens_cuda_int
+            batch_req_indices = batched_sampling_result.batch_req_indices
+            logprobs_req_set = set(logprobs_req_indices)
+            sampled_logprobs_list = []
+
+            # Build offsets for the GROUPED order
+            grouped_num_steps = req_num_steps[batch_req_indices]
+            grouped_offsets = torch.cat([
+                torch.zeros((1,), dtype=torch.int32, pin_memory=True),
+                grouped_num_steps.cumsum(dim=0)[:-1]
+            ])
+
+            # Reverse mapping: original_req_id → position in grouped result
+            req_to_grouped_pos = {
+                orig_id.item(): grouped_pos
+                for grouped_pos, orig_id in enumerate(batch_req_indices)
+            }
+
+            for req_id in range(len(requests)):
+                if req_id in logprobs_req_set:
+                    logprobs_idx = logprobs_req_indices.index(req_id)
+
+                    if logprobs_idx == 0:
+                        start_offset = 0
+                    else:
+                        start_offset = sum(req_num_steps[logprobs_req_indices[:logprobs_idx]].tolist())
+
+                    num_steps_this_req = req_num_steps[req_id].item()
+                    end_offset = start_offset + num_steps_this_req
+
+                    grouped_pos = req_to_grouped_pos[req_id]
+                    grouped_start = grouped_offsets[grouped_pos].item()
+                    grouped_end = grouped_start + grouped_num_steps[grouped_pos].item()
+
+                    sampled_tokens_this_req = sampled_tokens_cuda[grouped_start:grouped_end]
+
+                    step_indices = torch.arange(start_offset, end_offset, device=logprobs_cuda.device)
+                    sampled_logprobs_cuda = logprobs_cuda[step_indices, sampled_tokens_this_req.long()]
+
+                    sampled_logprobs_cpu = sampled_logprobs_cuda.to(device="cpu", non_blocking=True)
+                    sampled_logprobs_list.append((req_id, sampled_logprobs_cpu))
+
+            for req_id, sampled_logprobs in sampled_logprobs_list:
+                requests[req_id].py_sampled_logprobs = sampled_logprobs
+
         # Fill results into output buffers
         new_tokens_host = self._unbatch_sampling_results(
             batched_sampling_result,
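The new block above pulls the logprob of each sampled token out of the `[num_steps, vocab]` logprob matrix with advanced indexing. A small, self-contained illustration of that gather, using toy shapes and CPU tensors:

```python
import torch

# Toy logprob matrix: 3 generation steps over a vocabulary of 10 tokens.
logprobs = torch.log_softmax(torch.randn(3, 10), dim=-1)
sampled_tokens = torch.tensor([4, 1, 7])  # token chosen at each step

# Advanced indexing: row i is paired with column sampled_tokens[i].
step_indices = torch.arange(logprobs.size(0))
sampled_logprobs = logprobs[step_indices, sampled_tokens]  # shape [3]

# Equivalent formulation with torch.gather.
via_gather = torch.gather(logprobs, dim=-1, index=sampled_tokens.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(sampled_logprobs, via_gather)
```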