11"""Sampling parameters for text generation."""
22import os
33from typing import List , Optional , Union , Tuple
4-
4+ from transformers import GenerationConfig
55from .req_id_generator import MAX_BEST_OF
66
77_SAMPLING_EPS = 1e-5
1010
1111
1212class SamplingParams :
+
+    _do_sample: bool = False
+    _presence_penalty: float = 0.0
+    _frequency_penalty: float = 0.0
+    _repetition_penalty: float = 1.0
+    _temperature: float = 1.0
+    _top_p: float = 1.0
+    _top_k: int = -1  # -1 is for all
+
     def __init__(
         self,
         best_of: int = 1,
         n: int = None,  # number of results
-        do_sample: bool = False,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repetition_penalty: float = 1.0,
+        do_sample: bool = None,
+        presence_penalty: float = None,
+        frequency_penalty: float = None,
+        repetition_penalty: float = None,
         exponential_decay_length_penalty: Tuple[int, float] = (1, 1.0),
-        temperature: float = 1.0,
-        top_p: float = 1.0,
-        top_k: int = -1,  # -1 is for all
+        temperature: float = None,
+        top_p: float = None,
+        top_k: int = None,  # -1 is for all
         ignore_eos: bool = False,
         max_new_tokens: int = 16,
         min_new_tokens: int = 1,
@@ -46,14 +55,18 @@ def __init__(
     ) -> None:
         self.best_of = best_of
         self.n = n
-        self.do_sample = do_sample
-        self.presence_penalty = presence_penalty
-        self.frequency_penalty = frequency_penalty
-        self.repetition_penalty = repetition_penalty
+        self.do_sample = do_sample if do_sample is not None else SamplingParams._do_sample
+        self.presence_penalty = presence_penalty if presence_penalty is not None else SamplingParams._presence_penalty
+        self.frequency_penalty = (
+            frequency_penalty if frequency_penalty is not None else SamplingParams._frequency_penalty
+        )
+        self.repetition_penalty = (
+            repetition_penalty if repetition_penalty is not None else SamplingParams._repetition_penalty
+        )
         self.exponential_decay_length_penalty = exponential_decay_length_penalty
-        self.temperature = temperature
-        self.top_p = top_p
-        self.top_k = top_k
+        self.temperature = temperature if temperature is not None else SamplingParams._temperature
+        self.top_p = top_p if top_p is not None else SamplingParams._top_p
+        self.top_k = top_k if top_k is not None else SamplingParams._top_k
         self.ignore_eos = ignore_eos
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
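
[Review note] The hunk above moves every sampling knob to a None sentinel so that "argument not passed" can be told apart from "caller passed the old default". A minimal standalone sketch of the pattern; the class and field names here are illustrative, not from this diff:

    class Config:
        _timeout: float = 30.0  # class-level default, possibly refreshed at load time

        def __init__(self, timeout: float = None):
            # None means "caller did not choose", so fall back to the class default;
            # a `timeout or default` idiom would wrongly discard an explicit 0.0.
            self.timeout = timeout if timeout is not None else Config._timeout
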
@@ -81,6 +94,20 @@ def __init__(
             self.n = self.best_of
         return
 
+    @classmethod
+    def load_generation_cfg(cls, weight_dir):
+        try:
+            generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict()
+            cls._do_sample = generation_cfg.get("do_sample", False)
+            cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0)
+            cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0)
+            cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0)
+            cls._temperature = generation_cfg.get("temperature", 1.0)
+            cls._top_p = generation_cfg.get("top_p", 1.0)
+            cls._top_k = generation_cfg.get("top_k", -1)
+        except Exception:  # no generation config found; keep the hard-coded defaults
+            pass
+
     def verify(self):
         if self.best_of <= 0 or self.best_of > MAX_BEST_OF:
             raise ValueError(f"need 0 < best_of <= {MAX_BEST_OF}, but get {self.best_of}")
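
[Review note] For context, a sketch of how the new classmethod is meant to be used; the call site is not part of this diff, and the module path and weight directory below are assumptions:

    from lightllm.server.sampling_params import SamplingParams

    # Called once at server startup, before any request objects are built. If the
    # weight directory carries no generation_config.json, from_pretrained raises
    # and the except branch keeps the hard-coded class defaults.
    SamplingParams.load_generation_cfg("/models/my-llm")  # hypothetical path

    default_params = SamplingParams()                # inherits the model's defaults
    custom_params = SamplingParams(temperature=0.2)  # explicit arguments still win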