# SOFTWARE.

import logging
-import os
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional

+import yaml
from tqdm import tqdm

from lighteval.data import GenerativeTaskDataset
from lighteval.models.abstract_model import LightevalModel
from lighteval.models.endpoints.endpoint_model import ModelInfo
+from lighteval.models.model_input import GenerationParameters
from lighteval.models.model_output import (
    GenerativeResponse,
    LoglikelihoodResponse,
@@ -63,6 +64,32 @@
@dataclass
class LiteLLMModelConfig:
    model: str
+    provider: Optional[str] = None
+    base_url: Optional[str] = None
+    api_key: Optional[str] = None
+    generation_parameters: GenerationParameters = None
+
+    def __post_init__(self):
+        if self.generation_parameters is None:
+            self.generation_parameters = GenerationParameters()
+
+    @classmethod
+    def from_path(cls, path):
+        with open(path, "r") as f:
+            config = yaml.safe_load(f)["model"]
+
+        model = config["base_params"]["model_name"]
+        provider = config["base_params"].get("provider", None)
+        base_url = config["base_params"].get("base_url", None)
+        api_key = config["base_params"].get("api_key", None)
+        generation_parameters = GenerationParameters.from_dict(config)
+        return cls(
+            model=model,
+            provider=provider,
+            base_url=base_url,
+            generation_parameters=generation_parameters,
+            api_key=api_key,
+        )


class LiteLLMClient(LightevalModel):
@@ -79,15 +106,17 @@ def __init__(self, config, env_config) -> None:
            model_dtype=None,
            model_size="",
        )
-        self.provider = config.model.split("/")[0]
-        self.base_url = os.getenv(f"{self.provider.upper()}_BASE_URL", None)
+        self.model = config.model
+        self.provider = config.provider or config.model.split("/")[0]
+        self.base_url = config.base_url
+        self.api_key = config.api_key
+        self.generation_parameters = config.generation_parameters
+
        self.API_MAX_RETRY = 5
        self.API_RETRY_SLEEP = 3
        self.API_RETRY_MULTIPLIER = 2
        self.CONCURENT_CALLS = 20  # 100 leads to hitting Anthropic rate limits
-        self.TEMPERATURE = 0.3
-        self.TOP_P = 0.95
-        self.model = config.model
+
        self._tokenizer = encode
        self.pairwise_tokenization = False
        litellm.drop_params = True
@@ -126,18 +155,19 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
                kwargs = {
                    "model": self.model,
                    "messages": prompt,
-                    "max_completion_tokens": max_new_tokens,
                    "logprobs": return_logits if self.provider == "openai" else None,
                    "base_url": self.base_url,
                    "n": num_samples,
                    "caching": True,
+                    "api_key": self.api_key,
                }
                if "o1" in self.model:
                    logger.warning("O1 models do not support temperature, top_p, stop sequence. Disabling.")
                else:
-                    kwargs["temperature"] = self.TEMPERATURE
-                    kwargs["top_p"] = self.TOP_P
-                    kwargs["stop"] = stop_sequence
+                    kwargs.update(self.generation_parameters.to_litellm_dict())
+
+                if kwargs.get("max_completion_tokens", None) is None:
+                    kwargs["max_completion_tokens"] = max_new_tokens

                response = litellm.completion(**kwargs)

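# ---------------------------------------------------------------------------
# Usage sketch (not part of the diff above): one hypothetical way to exercise
# the new config path end to end. The YAML layout mirrors what from_path reads
# ("model" -> "base_params", plus generation settings handed to
# GenerationParameters.from_dict); the module path, the "generation" key name,
# and the model/api_key values are assumptions, not confirmed API.
from lighteval.models.litellm_model import LiteLLMModelConfig  # assumed module path

config_yaml = """\
model:
  base_params:
    model_name: openai/gpt-4o-mini       # hypothetical "provider/model" id
    provider: openai                      # optional; defaults to the model_name prefix
    base_url: https://api.openai.com/v1   # optional; None keeps the provider default
    api_key: sk-placeholder               # optional; hypothetical placeholder
  generation:                             # assumed key read by GenerationParameters.from_dict
    temperature: 0.3
    top_p: 0.95
"""

with open("litellm_config.yaml", "w") as f:
    f.write(config_yaml)

config = LiteLLMModelConfig.from_path("litellm_config.yaml")
# provider/base_url/api_key land on the client in __init__; generation
# parameters are merged into the litellm.completion kwargs per request,
# with max_completion_tokens falling back to the task's max_new_tokens.
print(config.model, config.provider, config.generation_parameters)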