@@ -145,6 +145,7 @@ def make_model(self):
             temperature=self.temperature,
             max_new_tokens=self.max_new_tokens,
             n_retry_server=self.n_retry_server,
+            log_probs=self.log_probs,
         )
     else:
         raise ValueError(f"Backend {self.backend} is not supported")
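For orientation, here is a minimal, self-contained sketch of the factory pattern this hunk edits, so the new keyword's path is visible end to end. The class names (ModelArgs, ChatModel) and the attribute defaults are assumptions, not the repository's actual definitions:

# Sketch only: ModelArgs/ChatModel are assumed names standing in for the real classes.
class ChatModel:
    def __init__(self, temperature, max_new_tokens, n_retry_server, log_probs=False):
        # The backend stores the flag so __call__ can request logprobs later.
        self.log_probs = log_probs

class ModelArgs:
    backend = "openai"
    temperature = 0.1
    max_new_tokens = 256
    n_retry_server = 4
    log_probs = True

    def make_model(self):
        if self.backend == "openai":
            return ChatModel(
                temperature=self.temperature,
                max_new_tokens=self.max_new_tokens,
                n_retry_server=self.n_retry_server,
                log_probs=self.log_probs,  # the kwarg added by this hunk
            )
        raise ValueError(f"Backend {self.backend} is not supported")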
@@ -237,7 +238,7 @@ def __init__(
         self.max_tokens = max_tokens
         self.max_retry = max_retry
         self.min_retry_wait_time = min_retry_wait_time
-        self.logprobs = log_probs
+        self.log_probs = log_probs

         # Get the API key from the environment variable if not provided
         if api_key_env_var:
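The context lines above mention falling back to an environment variable for the API key; a common shape for that step, given purely as an assumption about code not shown in the diff:

import os

# Assumed reconstruction of the api_key_env_var fallback referenced above;
# the default variable name is illustrative.
def resolve_api_key(api_key=None, api_key_env_var="OPENAI_API_KEY"):
    if api_key is None and api_key_env_var:
        api_key = os.environ.get(api_key_env_var)  # None if the variable is unset
    return api_key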
@@ -284,7 +285,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
             n=n_samples,
             temperature=temperature,
             max_tokens=self.max_tokens,
-            logprobs=self.logprobs,
+            log_probs=self.log_probs,
         )

         if completion.usage is None:
@@ -315,8 +316,8 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float

         if n_samples == 1:
             res = AIMessage(completion.choices[0].message.content)
-            if self.logprobs:
-                res["logprobs"] = completion.choices[0].logprobs
+            if self.log_probs:
+                res["log_probs"] = completion.choices[0].log_probs
             return res
         else:
             return [AIMessage(c.message.content) for c in completion.choices]
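Two things worth noting in this hunk: only the n_samples == 1 branch attaches the key, so multi-sample calls return plain AIMessage objects without logprobs; and renaming completion.choices[0].logprobs to .log_probs only works if completion comes from a wrapper exposing that spelling (the stock OpenAI SDK keeps logprobs on both the request parameter and the choice). A hypothetical caller, assuming AIMessage supports mapping-style access as the assignment above implies:

# Hypothetical usage; ChatModel/AIMessage are the classes touched by this diff,
# and the constructor arguments shown are assumptions.
model = ChatModel(model_name="gpt-4o-mini", log_probs=True)
res = model([{"role": "user", "content": "Hello"}], n_samples=1)
if "log_probs" in res:  # set only when log_probs=True and n_samples == 1
    print(res["log_probs"])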
@@ -429,7 +430,7 @@ def __init__(
         n_retry_server: Optional[int] = 4,
         log_probs: Optional[bool] = False,
     ):
-        super().__init__(model_name, base_model_name, n_retry_server)
+        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
         self.temperature = temperature
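Finally, a self-contained sketch of the inheritance change: the subclass now forwards log_probs so the parent can store it, which the earlier hunks rely on. The base-class and subclass names here are assumptions:

import logging
from typing import Optional

class BaseServerModel:  # assumed stand-in for the real parent class
    def __init__(self, model_name, base_model_name, n_retry_server, log_probs=False):
        self.model_name = model_name
        self.base_model_name = base_model_name
        self.n_retry_server = n_retry_server
        self.log_probs = log_probs

class ServerChatModel(BaseServerModel):  # assumed subclass name
    def __init__(
        self,
        model_name: str,
        base_model_name: str,
        temperature: float = 0.1,
        n_retry_server: Optional[int] = 4,
        log_probs: Optional[bool] = False,
    ):
        # Forward the new flag to the parent, mirroring the hunk above.
        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
        if temperature < 1e-3:
            logging.warning("Models might behave weirdly when temperature is too low.")
        self.temperature = temperature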