@@ -87,6 +87,7 @@ def make_model(self):
             model_name=self.model_name,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
         )


@@ -100,6 +101,7 @@ def make_model(self):
             model_name=self.model_name,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
         )


@@ -115,6 +117,7 @@ def make_model(self):
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
             deployment_name=self.deployment_name,
+            log_probs=self.log_probs,
         )


@@ -142,6 +145,7 @@ def make_model(self):
                 temperature=self.temperature,
                 max_new_tokens=self.max_new_tokens,
                 n_retry_server=self.n_retry_server,
+                log_probs=self.log_probs,
             )
         elif self.backend == "vllm":
             return VLLMChatModel(
@@ -232,6 +236,7 @@ def __init__(
         client_class=OpenAI,
         client_args=None,
         pricing_func=None,
+        log_probs=False,
     ):
         assert max_retry > 0, "max_retry should be greater than 0"

@@ -240,6 +245,7 @@ def __init__(
         self.max_tokens = max_tokens
         self.max_retry = max_retry
         self.min_retry_wait_time = min_retry_wait_time
+        self.log_probs = log_probs

         # Get the API key from the environment variable if not provided
         if api_key_env_var:
@@ -286,6 +292,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
             n=n_samples,
             temperature=temperature,
             max_tokens=self.max_tokens,
+            log_probs=self.log_probs,
         )

         if completion.usage is None:
@@ -315,7 +322,10 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
             tracking.TRACKER.instance(input_tokens, output_tokens, cost)

         if n_samples == 1:
-            return AIMessage(completion.choices[0].message.content)
+            res = AIMessage(completion.choices[0].message.content)
+            if self.log_probs:
+                res["log_probs"] = completion.choices[0].log_probs
+            return res
         else:
             return [AIMessage(c.message.content) for c in completion.choices]

@@ -335,6 +345,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         super().__init__(
             model_name=self.model_name,
@@ -346,6 +357,7 @@ def __init__(
             api_key_env_var="OPENAI_API_KEY",
             client_class=OpenAI,
             pricing_func=tracking.get_pricing_openai,
+            log_probs=log_probs,
         )


@@ -358,6 +370,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         client_args = {
             "base_url": "https://openrouter.ai/api/v1",
@@ -373,6 +386,7 @@ def __init__(
             client_class=OpenAI,
             client_args=client_args,
             pricing_func=tracking.get_pricing_openrouter,
+            log_probs=log_probs,
         )


@@ -386,6 +400,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
         endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
@@ -406,6 +421,7 @@ def __init__(
             client_class=AzureOpenAI,
             client_args=client_args,
             pricing_func=tracking.get_pricing_openai,
+            log_probs=log_probs,
         )


@@ -419,8 +435,9 @@ def __init__(
         temperature: Optional[int] = 1e-1,
         max_new_tokens: Optional[int] = 512,
         n_retry_server: Optional[int] = 4,
+        log_probs: Optional[bool] = False,
     ):
-        super().__init__(model_name, base_model_name, n_retry_server)
+        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
         self.temperature = temperature
@@ -429,7 +446,7 @@ def __init__(
             token = os.environ["TGI_TOKEN"]

         client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)


 class VLLMChatModel(ChatModel):
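
For the TGI-backed model, the last hunk forwards log_probs as the details flag of huggingface_hub's InferenceClient.text_generation. A rough sketch of what that changes, based on huggingface_hub's documented behavior rather than this commit; the endpoint URL and prompt are placeholders:

from functools import partial
from huggingface_hub import InferenceClient

client = InferenceClient(model="http://localhost:8080")  # placeholder TGI endpoint
llm = partial(client.text_generation, max_new_tokens=512, details=True)

out = llm("Hello, world")
# With details=True the call returns a TextGenerationOutput instead of a plain
# string: generated_text holds the completion and details.tokens carries
# per-token information, including logprob values.
print(out.generated_text)
print([t.logprob for t in out.details.tokens])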
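
For context, a minimal usage sketch of the new flag on the OpenAI-backed model. This is not part of the commit: the import path, model name, and prompt are assumptions for illustration, only the arguments visible in the diff are shown, and AIMessage is treated as dict-like, as suggested by res["log_probs"] in __call__:

from agentlab.llm.chat_api import OpenAIChatModel  # assumed import path

# Build a chat model with log-prob collection enabled (new in this change).
chat = OpenAIChatModel(
    model_name="gpt-4o-mini",  # hypothetical model choice
    max_tokens=100,
    log_probs=True,
)

messages = [{"role": "user", "content": "Say hello."}]
answer = chat(messages)  # n_samples defaults to 1, so a single AIMessage is returned

print(answer)  # the assistant's reply
if "log_probs" in answer:
    # Per-choice log-probabilities attached by __call__ when log_probs=True.
    print(answer["log_probs"])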