@@ -229,6 +229,13 @@ class SpecMetadata:
     # whether the spec-dec mode is a dynamic tree.
     is_spec_dec_dynamic_tree: bool = False
 
+    # For non-greedy sampling on 1-model.
+    allow_advanced_sampling: bool = False
+    # Sampling parameters for non-greedy sampling (per-request).
+    temperatures: Optional[torch.Tensor] = None
+    top_ks: Optional[torch.Tensor] = None
+    top_ps: Optional[torch.Tensor] = None
+
     def __post_init__(self):
         pass
 
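The new fields are flat per-token buffers rather than per-request scalars: each request contributes one slot per token sampled in an iteration, and the method added in the next hunk packs those slots contiguously. A minimal CPU-only sketch of that packing (my illustration, with a hypothetical `Req` and `pack_sampling_params` standing in for `LlmRequest` and the loop below):

```python
# Illustrative sketch only: `Req` and `pack_sampling_params` are hypothetical
# stand-ins for LlmRequest and the packing loop in the diff below.
from dataclasses import dataclass


@dataclass
class Req:
    temperature: float   # stand-in for sampling_config.temperature[0]
    is_generation: bool  # stand-in for state == GENERATION_IN_PROGRESS


def pack_sampling_params(requests: list[Req], max_draft_len: int) -> list[float]:
    temperatures = []
    for r in requests:
        # Generation requests sample 1 target token plus max_draft_len draft
        # positions; context requests have no draft tokens yet.
        num_tokens = 1 + max_draft_len if r.is_generation else 1
        temperatures.extend(r.temperature for _ in range(num_tokens))
    return temperatures


# One context request and one generation request with max_draft_len=3:
print(pack_sampling_params([Req(0.7, False), Req(0.9, True)], max_draft_len=3))
# -> [0.7, 0.9, 0.9, 0.9, 0.9]
```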
@@ -264,3 +271,83 @@ def maybe_capture_hidden_states(self, layer_id: int,
         Some spec decode algorithms require hidden states from the target
         model. Use this method to record them. By default, does nothing.
         """
+
+    def populate_sampling_params_for_one_model(
+            self, requests: list["LlmRequest"]) -> None:
+        """
+        Set up top-p/top-k/temperature tensors for the 1-model sampler.
+        """
+        from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
+        from tensorrt_llm.sampling_params import SamplingParams
+
+        if not self.allow_advanced_sampling or not self.spec_dec_mode.use_one_engine(
+        ):
+            return
+
+        if self.temperatures is None:
+            # Seed once on first use so sampling stays deterministic across ranks.
+            torch.manual_seed(0)
+
+        temperatures = []
+        top_ks = []
+        top_ps = []
+
+        # Use a very small temperature when it is disabled, to avoid division by zero.
+        DISABLE_TEMP_VAL = 1e-5
+        # A very large value disables top-k; a top-p of 1.0 disables top-p.
+        DISABLE_TOPK_VAL = torch.iinfo(torch.int32).max
+        DISABLE_TOPP_VAL = 1.0
+
+        for request in requests:
+            sampling_config = request.sampling_config
+            temp = sampling_config.temperature
+            temp_val = temp[0] if temp is not None and len(temp) > 0 else None
+
+            tk = sampling_config.top_k
+            tk_val = tk[0] if tk is not None and len(tk) > 0 else None
+
+            tp = sampling_config.top_p
+            tp_val = tp[0] if tp is not None and len(tp) > 0 else None
+
+            # Context requests have no draft tokens yet.
+            num_tokens = 1 + self.max_draft_len if request.state == LlmRequestState.GENERATION_IN_PROGRESS else 1
+
+            is_greedy = SamplingParams.params_imply_greedy_decoding(
+                temperature=temp_val,
+                top_k=tk_val,
+                top_p=tp_val,
+                use_beam_search=False)
+
+            temp_val = DISABLE_TEMP_VAL if is_greedy or temp_val is None or temp_val == 0 else temp_val
+            tk_val = DISABLE_TOPK_VAL if is_greedy or tk_val is None or tk_val <= 0 else tk_val
+            tp_val = DISABLE_TOPP_VAL if is_greedy or tp_val is None else tp_val
+
+            temperatures.extend(temp_val for _ in range(num_tokens))
+            top_ks.extend(tk_val for _ in range(num_tokens))
+            top_ps.extend(tp_val for _ in range(num_tokens))
+
+        if self.temperatures is None:
+            self.temperatures = torch.ones(
+                (self.max_draft_len + 1) * self.max_num_requests,
+                dtype=torch.float32,
+                device='cuda')
+            self.top_ks = torch.zeros(
+                (self.max_draft_len + 1) * self.max_num_requests,
+                dtype=torch.int32,
+                device='cuda')
+            self.top_ps = torch.ones(
+                (self.max_draft_len + 1) * self.max_num_requests,
+                dtype=torch.float32,
+                device='cuda')
+
+        self.temperatures[:len(temperatures)].copy_(torch.tensor(
+            temperatures, dtype=torch.float32, pin_memory=True),
+                                                    non_blocking=True)
+        self.top_ks[:len(top_ks)].copy_(torch.tensor(top_ks,
+                                                     dtype=torch.int32,
+                                                     pin_memory=True),
+                                        non_blocking=True)
+        self.top_ps[:len(top_ps)].copy_(torch.tensor(top_ps,
+                                                     dtype=torch.float32,
+                                                     pin_memory=True),
+                                        non_blocking=True)
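For context (my sketch, not code from this change): the sentinel values appear chosen so a downstream sampling kernel can treat every slot uniformly, with no greedy/non-greedy branching. Assuming the usual temperature-scale, top-k, top-p pipeline, each sentinel reduces its stage to a no-op, which a short CPU-only torch check can illustrate:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])

# DISABLE_TEMP_VAL = 1e-5: dividing by a near-zero temperature makes the
# softmax collapse onto the argmax, i.e. effectively greedy decoding.
probs = torch.softmax(logits / 1e-5, dim=-1)
assert torch.argmax(probs) == torch.argmax(logits)
assert probs.max() > 0.9999  # effectively one-hot

# DISABLE_TOPK_VAL = INT32_MAX: a kernel that clamps k to the vocab size
# keeps every token, so top-k filtering removes nothing.
k = min(torch.iinfo(torch.int32).max, logits.numel())
assert torch.topk(logits, k).values.numel() == logits.numel()

# DISABLE_TOPP_VAL = 1.0: in the standard nucleus filter, a token is dropped
# when the cumulative mass *before* it already exceeds p. With p = 1.0 that
# can never happen, so the whole distribution is kept.
sorted_probs = torch.softmax(logits, dim=-1).sort(descending=True).values
exclusive_cum = sorted_probs.cumsum(dim=-1) - sorted_probs
assert not (exclusive_cum > 1.0).any()
```

Preallocating the device buffers at the worst-case size `(max_draft_len + 1) * max_num_requests` and staging the host values in pinned memory keeps the tensor shapes static and lets the `copy_(..., non_blocking=True)` host-to-device transfers run asynchronously, which is presumably why the buffers are allocated once rather than rebuilt per step.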