@@ -14,6 +14,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
1414 b_top_ks ,
1515 b_length_penalty_param ,
1616 b_mask_eos_reqs ,
17+ is_all_greedy ,
1718 ) = _get_post_sample_tensors (reqs )
1819 eos_ids = torch .tensor (eos_id , dtype = torch .int32 , device = "cpu" , pin_memory = True ).cuda (non_blocking = True )
1920
@@ -61,7 +62,12 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
6162 logits .div_ (b_temperatures .view ((- 1 , 1 )))
6263 probs = torch .softmax (logits , dim = - 1 )
6364
64- if get_env_start_args ().sampling_backend == "triton" :
65+ if is_all_greedy :
66+ batch_next_token_ids = torch .argmax (logits , - 1 )
67+ batch_next_token_probs = torch .gather (probs , dim = 1 , index = batch_next_token_ids .view (- 1 , 1 ))
68+ return batch_next_token_ids .view (- 1 ), torch .log (batch_next_token_probs ).view (- 1 )
69+
70+ elif get_env_start_args ().sampling_backend == "triton" :
6571 probs_sort , probs_idx = _top_p_top_k (probs , b_top_ps , b_top_ks )
6672 sampled_index = torch .multinomial (probs_sort , num_samples = 1 , replacement = True )
6773 next_token_ids = torch .gather (probs_idx , dim = 1 , index = sampled_index )
@@ -104,6 +110,8 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
104110 top_ks : List [int ] = []
105111 length_penalty_param : List [int ] = []
106112 mask_eos_reqs : List [bool ] = []
113+ is_all_greedy = True
114+
107115 for i , req_obj in enumerate (reqs ):
108116 sample_param = req_obj .sampling_param
109117 shm_param = sample_param .shm_param
@@ -114,7 +122,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
114122
115123 temperatures .append (shm_param .temperature )
116124 top_ps .append (shm_param .top_p )
117- top_ks .append (shm_param .top_k )
125+ top_k_val = shm_param .top_k
126+ top_ks .append (top_k_val )
127+ if top_k_val > 1 :
128+ is_all_greedy = False
118129 req_idxes .append (req_obj .req_idx )
119130
120131 req_idxes_cpu = torch .tensor (req_idxes , dtype = torch .int32 , device = "cpu" , pin_memory = True )
@@ -131,4 +142,5 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
131142 top_ks_cpu .cuda (non_blocking = True ),
132143 length_penalty_param_cpu .cuda (non_blocking = True ),
133144 mask_eos_reqs_cpu .cuda (non_blocking = True ),
145+ is_all_greedy ,
134146 )
0 commit comments