@@ -87,6 +87,7 @@ def make_model(self):
8787 model_name = self .model_name ,
8888 temperature = self .temperature ,
8989 max_tokens = self .max_new_tokens ,
90+ log_probs = self .log_probs ,
9091 )
9192
9293
@@ -100,6 +101,7 @@ def make_model(self):
100101 model_name = self .model_name ,
101102 temperature = self .temperature ,
102103 max_tokens = self .max_new_tokens ,
104+ log_probs = self .log_probs ,
103105 )
104106
105107
@@ -115,6 +117,7 @@ def make_model(self):
115117 temperature = self .temperature ,
116118 max_tokens = self .max_new_tokens ,
117119 deployment_name = self .deployment_name ,
120+ log_probs = self .log_probs ,
118121 )
119122
120123
@@ -225,6 +228,7 @@ def __init__(
225228 client_class = OpenAI ,
226229 client_args = None ,
227230 pricing_func = None ,
231+ log_probs = False ,
228232 ):
229233 assert max_retry > 0 , "max_retry should be greater than 0"
230234
@@ -233,6 +237,7 @@ def __init__(
233237 self .max_tokens = max_tokens
234238 self .max_retry = max_retry
235239 self .min_retry_wait_time = min_retry_wait_time
240+ self .logprobs = log_probs
236241
237242 # Get the API key from the environment variable if not provided
238243 if api_key_env_var :
@@ -279,6 +284,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
279284 n = n_samples ,
280285 temperature = temperature ,
281286 max_tokens = self .max_tokens ,
287+ logprobs = self .logprobs ,
282288 )
283289
284290 if completion .usage is None :
@@ -308,7 +314,10 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
308314 tracking .TRACKER .instance (input_tokens , output_tokens , cost )
309315
310316 if n_samples == 1 :
311- return AIMessage (completion .choices [0 ].message .content )
317+ res = AIMessage (completion .choices [0 ].message .content )
318+ if self .logprobs :
319+ res ["logprobs" ] = completion .choices [0 ].logprobs
320+ return res
312321 else :
313322 return [AIMessage (c .message .content ) for c in completion .choices ]
314323
@@ -328,6 +337,7 @@ def __init__(
328337 max_tokens = 100 ,
329338 max_retry = 4 ,
330339 min_retry_wait_time = 60 ,
340+ log_probs = False ,
331341 ):
332342 super ().__init__ (
333343 model_name = model_name ,
@@ -339,6 +349,7 @@ def __init__(
339349 api_key_env_var = "OPENAI_API_KEY" ,
340350 client_class = OpenAI ,
341351 pricing_func = tracking .get_pricing_openai ,
352+ log_probs = log_probs ,
342353 )
343354
344355
@@ -351,6 +362,7 @@ def __init__(
351362 max_tokens = 100 ,
352363 max_retry = 4 ,
353364 min_retry_wait_time = 60 ,
365+ log_probs = False ,
354366 ):
355367 client_args = {
356368 "base_url" : "https://openrouter.ai/api/v1" ,
@@ -366,6 +378,7 @@ def __init__(
366378 client_class = OpenAI ,
367379 client_args = client_args ,
368380 pricing_func = tracking .get_pricing_openrouter ,
381+ log_probs = log_probs ,
369382 )
370383
371384
@@ -379,6 +392,7 @@ def __init__(
379392 max_tokens = 100 ,
380393 max_retry = 4 ,
381394 min_retry_wait_time = 60 ,
395+ log_probs = False ,
382396 ):
383397 api_key = api_key or os .getenv ("AZURE_OPENAI_API_KEY" )
384398 endpoint = os .getenv ("AZURE_OPENAI_ENDPOINT" )
@@ -399,6 +413,7 @@ def __init__(
399413 client_class = AzureOpenAI ,
400414 client_args = client_args ,
401415 pricing_func = tracking .get_pricing_openai ,
416+ log_probs = log_probs ,
402417 )
403418
404419
@@ -412,6 +427,7 @@ def __init__(
412427 temperature : Optional [int ] = 1e-1 ,
413428 max_new_tokens : Optional [int ] = 512 ,
414429 n_retry_server : Optional [int ] = 4 ,
430+ log_probs : Optional [bool ] = False ,
415431 ):
416432 super ().__init__ (model_name , base_model_name , n_retry_server )
417433 if temperature < 1e-3 :
@@ -422,4 +438,4 @@ def __init__(
422438 token = os .environ ["TGI_TOKEN" ]
423439
424440 client = InferenceClient (model = model_url , token = token )
425- self .llm = partial (client .text_generation , max_new_tokens = max_new_tokens )
441+ self .llm = partial (client .text_generation , max_new_tokens = max_new_tokens , details = log_probs )
0 commit comments