@@ -261,7 +261,7 @@ def __init__(
             **client_args,
         )
 
-    def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
+    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
         # Initialize retry tracking attributes
         self.retries = 0
         self.success = False
@@ -271,12 +271,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
         e = None
         for itr in range(self.max_retry):
             self.retries += 1
+            temperature = temperature if temperature is not None else self.temperature
             try:
                 completion = self.client.chat.completions.create(
                     model=self.model_name,
                     messages=messages,
                     n=n_samples,
-                    temperature=self.temperature,
+                    temperature=temperature,
                     max_tokens=self.max_tokens,
                 )
 
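Side note on usage: the new optional temperature argument overrides the instance default for a single call, falling back to self.temperature when omitted. A minimal sketch, assuming a hypothetical ChatModel wrapper built on this __call__:

    llm = ChatModel(model_name="gpt-4o-mini", temperature=0.5)  # hypothetical wrapper and model name
    llm(messages=[{"role": "user", "content": "Hi"}])                   # uses the instance default (0.5)
    llm(messages=[{"role": "user", "content": "Hi"}], temperature=0.0)  # per-call override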
@@ -414,11 +415,10 @@ def __init__(
         super().__init__(model_name, n_retry_server)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
+        self.temperature = temperature
 
         if token is None:
             token = os.environ["TGI_TOKEN"]
 
         client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(
-            client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens
-        )
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
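Design note: removing temperature from the partial defers the choice to generation time, so the TGI path can honor the same per-call override as the chat-completions client above. A sketch of what the call site might look like (assumed, not shown in this hunk); huggingface_hub's InferenceClient.text_generation does accept a temperature keyword:

    # Resolve the per-call override, mirroring the chat client's fallback logic
    temperature = temperature if temperature is not None else self.temperature
    answer = self.llm(prompt, temperature=temperature)  # self.llm = partial(client.text_generation, max_new_tokens=...)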