2020# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121# SOFTWARE.
2222
23- from dataclasses import dataclass
23+ from dataclasses import dataclass , asdict
2424from typing import Optional
2525
2626
@@ -57,24 +57,7 @@ def from_dict(cls, config_dict: dict):
5757 }
5858 }
5959 """
60- if "generation" not in config_dict :
61- return GenerationParameters ()
62- return GenerationParameters (
63- early_stopping = config_dict ["generation" ].get ("early_stopping" , None ),
64- repetition_penalty = config_dict ["generation" ].get ("repetition_penalty" , None ),
65- frequency_penalty = config_dict ["generation" ].get ("frequency_penalty" , None ),
66- length_penalty = config_dict ["generation" ].get ("length_penalty" , None ),
67- presence_penalty = config_dict ["generation" ].get ("presence_penalty" , None ),
68- max_new_tokens = config_dict ["generation" ].get ("max_new_tokens" , None ),
69- min_new_tokens = config_dict ["generation" ].get ("min_new_tokens" , None ),
70- seed = config_dict ["generation" ].get ("seed" , None ),
71- stop_tokens = config_dict ["generation" ].get ("stop_tokens" , None ),
72- temperature = config_dict ["generation" ].get ("temperature" , None ),
73- top_k = config_dict ["generation" ].get ("top_k" , None ),
74- min_p = config_dict ["generation" ].get ("min_p" , None ),
75- top_p = config_dict ["generation" ].get ("top_p" , None ),
76- truncate_prompt = config_dict ["generation" ].get ("truncate_prompt" , None ),
77- )
60+ return GenerationParameters (** config_dict .get ("generation" , {}))
7861
7962 def to_vllm_openai_dict (self ) -> dict :
8063 """Selects relevant generation and sampling parameters for vllm and openai models.
@@ -85,23 +68,7 @@ def to_vllm_openai_dict(self) -> dict:
8568 """
8669 # Task specific sampling params to set in model: n, best_of, use_beam_search
8770 # Generation specific params to set in model: logprobs, prompt_logprobs
88- args = {
89- "presence_penalty" : self .presence_penalty ,
90- "frequency_penalty" : self .frequency_penalty ,
91- "repetition_penalty" : self .repetition_penalty ,
92- "temperature" : self .temperature ,
93- "top_p" : self .top_p ,
94- "top_k" : self .top_k ,
95- "min_p" : self .min_p ,
96- "seed" : self .seed ,
97- "length_penalty" : self .length_penalty ,
98- "early_stopping" : self .early_stopping ,
99- "stop" : self .stop_tokens ,
100- "max_tokens" : self .max_new_tokens ,
101- "min_tokens" : self .min_new_tokens ,
102- "truncate_prompt_tokens" : self .truncate_prompt ,
103- }
104- return {k : v for k , v in args .items () if v is not None }
71+ return {k : v for k , v in asdict (self ).items () if v is not None }
10572
10673 def to_transformers_dict (self ) -> dict :
10774 """Selects relevant generation and sampling parameters for transformers models.
@@ -117,7 +84,7 @@ def to_transformers_dict(self) -> dict:
11784 args = {
11885 "max_new_tokens" : self .max_new_tokens ,
11986 "min_new_tokens" : self .min_new_tokens ,
120- "early_stopping" : self .early_stopping or False ,
87+ "early_stopping" : self .early_stopping ,
12188 "stop_strings" : self .stop_tokens ,
12289 "temperature" : self .temperature ,
12390 "top_k" : self .top_k ,
0 commit comments