@@ -991,20 +991,33 @@ def handle_logprobs(
             )
         else:
             assert beam_width == 1, "beam width must be 1 for non-beam search"
-
+
             sampled_tokens = request.get_tokens(0)[-count:]

             if request.py_num_logprobs == 0:
                 # Return only the sampled token's logprob
                 # Compute at least top-1 to determine rank
-                if hasattr(request, 'py_sampled_logprobs') and request.py_sampled_logprobs is not None:
+                if (
+                    hasattr(request, "py_sampled_logprobs")
+                    and request.py_sampled_logprobs is not None
+                ):
                     sampled_logprobs = request.py_sampled_logprobs[:count]
                     topk_log_probs_vals = request.py_topk_logprobs_vals[:count]  # At least k=1
                     topk_log_probs_indices = request.py_topk_logprobs_indices[:count]

                     token_log_probs = []
-                    for step, (sampled_token, sampled_logprob, topk_tokens, topk_logprobs) in enumerate(
-                        zip(sampled_tokens, sampled_logprobs, topk_log_probs_indices, topk_log_probs_vals)
+                    for step, (
+                        sampled_token,
+                        sampled_logprob,
+                        topk_tokens,
+                        topk_logprobs,
+                    ) in enumerate(
+                        zip(
+                            sampled_tokens,
+                            sampled_logprobs,
+                            topk_log_probs_indices,
+                            topk_log_probs_vals,
+                        )
                     ):
                         topk_tokens_list = topk_tokens.tolist()
                         if sampled_token in topk_tokens_list:
@@ -1014,38 +1027,53 @@ def handle_logprobs(
                             # TODO: fix rank
                             rank = 2

-                        step_dict = {sampled_token: Logprob(logprob=sampled_logprob.item(), rank=rank)}
+                        step_dict = {
+                            sampled_token: Logprob(logprob=sampled_logprob.item(), rank=rank)
+                        }
                         token_log_probs.append(step_dict)
                 else:
-                    raise ValueError("py_sampled_logprobs not available when py_num_logprobs == 0")
+                    raise ValueError(
+                        "py_sampled_logprobs not available when py_num_logprobs == 0"
+                    )
             else:
                 # Return top-K logprobs + logprob of sampled token
                 sampled_logprobs = request.py_sampled_logprobs[:count]
                 topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
                 topk_log_probs_indices = request.py_topk_logprobs_indices[:count]

                 token_log_probs = []
-                for step, (sampled_token, sampled_logprob, topk_tokens, topk_logprobs) in enumerate(
-                    zip(sampled_tokens, sampled_logprobs, topk_log_probs_indices, topk_log_probs_vals)
+                for step, (
+                    sampled_token,
+                    sampled_logprob,
+                    topk_tokens,
+                    topk_logprobs,
+                ) in enumerate(
+                    zip(
+                        sampled_tokens,
+                        sampled_logprobs,
+                        topk_log_probs_indices,
+                        topk_log_probs_vals,
+                    )
                 ):
                     step_dict = {}
                     topk_tokens_list = topk_tokens.tolist()
                     topk_logprobs_list = topk_logprobs.tolist()

-                    for rank_idx, (token, logprob) in enumerate(zip(topk_tokens_list, topk_logprobs_list), start=1):
+                    for rank_idx, (token, logprob) in enumerate(
+                        zip(topk_tokens_list, topk_logprobs_list), start=1
+                    ):
                         step_dict[token] = Logprob(logprob=logprob, rank=rank_idx)

                     if sampled_token not in step_dict:
                         # TODO: fix rank
                         step_dict[sampled_token] = Logprob(
-                            logprob=sampled_logprob.item(),
-                            rank=len(topk_tokens_list) + 1
+                            logprob=sampled_logprob.item(), rank=len(topk_tokens_list) + 1
                         )
                     token_log_probs.append(step_dict)
-
+
             # Wrap in list for non-beam search (beam_width=1)
             token_log_probs = [token_log_probs]
-
+
             request.py_result.append_log_probs(token_log_probs)

 def finish_if_reason(
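
The per-step dictionaries built above pair each token id with its logprob and a rank. Below is a minimal standalone sketch of that rank-assignment logic (not part of the diff; `Logprob` here is a stand-in dataclass for the project's own type, and only the `logprob` and `rank` fields are assumed):

# Standalone sketch, not part of the diff.
from dataclasses import dataclass

@dataclass
class Logprob:
    logprob: float
    rank: int

def build_step_dict(sampled_token, sampled_logprob, topk_tokens, topk_logprobs):
    # Top-K entries get ranks 1..K in descending-probability order.
    step_dict = {
        token: Logprob(logprob=lp, rank=r)
        for r, (token, lp) in enumerate(zip(topk_tokens, topk_logprobs), start=1)
    }
    # A sampled token outside the top-K is appended; its true rank is unknown
    # here (hence the "TODO: fix rank" in the diff), so K+1 is only a lower bound.
    if sampled_token not in step_dict:
        step_dict[sampled_token] = Logprob(sampled_logprob, rank=len(topk_tokens) + 1)
    return step_dict

# Sampled token 7 is outside the top-2, so it is added with rank 3.
print(build_step_dict(7, -2.3, [3, 11], [-0.4, -1.1]))
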
@@ -2518,16 +2546,18 @@ def _process_requests(
             device=logits_cuda.device, non_blocking=True
         )
         logprobs_cuda = F.log_softmax(
-            logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True),
+            logits_cuda[logprobs_logit_indices_cuda].to(
+                dtype=torch.float32, non_blocking=True
+            ),
             dim=-1,
         )

-        max_k = max(max(1, req.py_num_logprobs) for req in requests if req.py_num_logprobs is not None)
-        topk_vals_cuda, topk_indices_cuda = torch.topk(
-            logprobs_cuda,
-            k=max_k,
-            dim=-1
+        max_k = max(
+            max(1, req.py_num_logprobs)
+            for req in requests
+            if req.py_num_logprobs is not None
         )
+        topk_vals_cuda, topk_indices_cuda = torch.topk(logprobs_cuda, k=max_k, dim=-1)
         # Use a single D2H copy to reduce overheads
         topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
         topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
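
This hunk computes log-softmax and top-k on the GPU, then stages the results in pinned host buffers so the device-to-host copy can be non-blocking. A rough standalone sketch of that pattern, assuming a CUDA device and purely illustrative shapes and k:

# Standalone sketch (requires a CUDA device); shapes and k are illustrative.
import torch
import torch.nn.functional as F

logits_cuda = torch.randn(4, 32000, device="cuda")      # [num_steps, vocab_size]
logprobs_cuda = F.log_softmax(logits_cuda.to(torch.float32), dim=-1)

max_k = 5                                                # max over per-request k values
topk_vals_cuda, topk_indices_cuda = torch.topk(logprobs_cuda, k=max_k, dim=-1)

# Pinned (page-locked) host buffers let copy_(..., non_blocking=True) run as a
# true asynchronous D2H transfer instead of a blocking copy.
topk_vals = torch.empty(topk_vals_cuda.shape, dtype=topk_vals_cuda.dtype, pin_memory=True)
topk_indices = torch.empty(topk_indices_cuda.shape, dtype=topk_indices_cuda.dtype, pin_memory=True)
topk_vals.copy_(topk_vals_cuda, non_blocking=True)
topk_indices.copy_(topk_indices_cuda, non_blocking=True)

# The copies complete asynchronously; synchronize before reading on the host.
torch.cuda.current_stream().synchronize()
print(topk_vals[0, :3], topk_indices[0, :3])
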
@@ -2542,11 +2572,9 @@ def _process_requests(
             # Store at least k=1 for all requests (including logprobs=0) to compute ranks
             k_for_req = max(1, req.py_num_logprobs)
             # NB: Assigning views on memory which is being filled asynchronously
-            req.py_topk_logprobs_vals = topk_vals[
-                current_offset:next_offset, : k_for_req
-            ]
+            req.py_topk_logprobs_vals = topk_vals[current_offset:next_offset, :k_for_req]
             req.py_topk_logprobs_indices = topk_indices[
-                current_offset:next_offset, : k_for_req
+                current_offset:next_offset, :k_for_req
             ]

             # context requests do not have multiple input beams, but they need multiple output beams
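
Each request then receives views into those shared buffers, keeping at least one column so a rank can still be derived when logprobs=0 was requested. A small sketch of the offset bookkeeping, using made-up step counts and k values:

# Standalone sketch with made-up per-request values.
import torch

num_steps = [2, 1, 3]        # decode steps per request
num_logprobs = [0, 4, 2]     # requested top-k per request (0 => sampled token only)
max_k = max(max(1, k) for k in num_logprobs)

topk_vals = torch.zeros(sum(num_steps), max_k)   # stands in for the pinned buffer

views = []
current_offset = 0
for steps, k in zip(num_steps, num_logprobs):
    next_offset = current_offset + steps
    k_for_req = max(1, k)    # keep at least one column so a rank can be computed
    views.append(topk_vals[current_offset:next_offset, :k_for_req])  # a view, no copy
    current_offset = next_offset

print([tuple(v.shape) for v in views])           # [(2, 1), (1, 4), (3, 2)]
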
@@ -2581,15 +2609,16 @@ def _process_requests(

         # Build offsets for the GROUPED order
         grouped_num_steps = req_num_steps[batch_req_indices]
-        grouped_offsets = torch.cat([
-            torch.zeros((1,), dtype=torch.int32, pin_memory=True),
-            grouped_num_steps.cumsum(dim=0)[:-1]
-        ])
+        grouped_offsets = torch.cat(
+            [
+                torch.zeros((1,), dtype=torch.int32, pin_memory=True),
+                grouped_num_steps.cumsum(dim=0)[:-1],
+            ]
+        )

         # Reverse mapping: original_req_id → position in grouped result
         req_to_grouped_pos = {
-            orig_id.item(): grouped_pos
-            for grouped_pos, orig_id in enumerate(batch_req_indices)
+            orig_id.item(): grouped_pos for grouped_pos, orig_id in enumerate(batch_req_indices)
         }

         for req_id in range(len(requests)):
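
The grouped offsets are an exclusive prefix sum over per-request step counts in grouped order, and `req_to_grouped_pos` inverts the grouping. A toy illustration with made-up values (pin_memory omitted so it runs on a CPU-only machine):

# Standalone sketch with toy values.
import torch

batch_req_indices = torch.tensor([2, 0, 3])      # grouped order of request ids
req_num_steps = torch.tensor([1, 2, 3, 2])       # steps per original request

grouped_num_steps = req_num_steps[batch_req_indices]             # [3, 1, 2]
# Exclusive prefix sum: start offset of each grouped request along the step axis.
grouped_offsets = torch.cat(
    [
        torch.zeros((1,), dtype=torch.int64),
        grouped_num_steps.cumsum(dim=0)[:-1],
    ]
)
print(grouped_offsets.tolist())                                  # [0, 3, 4]

# Reverse mapping: original request id -> position in the grouped result.
req_to_grouped_pos = {
    orig_id.item(): grouped_pos for grouped_pos, orig_id in enumerate(batch_req_indices)
}
print(req_to_grouped_pos)                                        # {2: 0, 0: 1, 3: 2}
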
@@ -2599,7 +2628,9 @@ def _process_requests(
             if logprobs_idx == 0:
                 start_offset = 0
             else:
-                start_offset = sum(req_num_steps[logprobs_req_indices[:logprobs_idx]].tolist())
+                start_offset = sum(
+                    req_num_steps[logprobs_req_indices[:logprobs_idx]].tolist()
+                )

             num_steps_this_req = req_num_steps[req_id].item()
             end_offset = start_offset + num_steps_this_req
@@ -2610,8 +2641,12 @@ def _process_requests(

             sampled_tokens_this_req = sampled_tokens_cuda[grouped_start:grouped_end]

-            step_indices = torch.arange(start_offset, end_offset, device=logprobs_cuda.device)
-            sampled_logprobs_cuda = logprobs_cuda[step_indices, sampled_tokens_this_req.long()]
+            step_indices = torch.arange(
+                start_offset, end_offset, device=logprobs_cuda.device
+            )
+            sampled_logprobs_cuda = logprobs_cuda[
+                step_indices, sampled_tokens_this_req.long()
+            ]

             sampled_logprobs_cpu = sampled_logprobs_cuda.to(device="cpu", non_blocking=True)
             sampled_logprobs_list.append((req_id, sampled_logprobs_cpu))
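
Each request's sampled-token logprobs are gathered with advanced indexing: pairing a row index per step with that step's sampled token id picks exactly one value per row. A minimal sketch with toy shapes:

# Standalone sketch with toy shapes.
import torch
import torch.nn.functional as F

logprobs = F.log_softmax(torch.randn(5, 10), dim=-1)    # 5 steps, vocab of 10
sampled_tokens = torch.tensor([3, 7, 0, 9, 2])          # sampled token id per step

start_offset, end_offset = 1, 4                         # this request owns steps 1..3
step_indices = torch.arange(start_offset, end_offset)
# Advanced indexing pairs row i with column sampled_tokens[i]: one logprob per step.
sampled_logprobs = logprobs[step_indices, sampled_tokens[start_offset:end_offset].long()]
print(sampled_logprobs.shape)                           # torch.Size([3])
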