@@ -26,16 +26,21 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, Literal
 
+from pydantic import BaseModel
 from tqdm import tqdm
 
 from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
+from lighteval.utils.utils import as_list
 
 
 logging.getLogger("openai").setLevel(logging.ERROR)
 logging.getLogger("httpx").setLevel(logging.ERROR)
 logger = logging.getLogger(__name__)
 
 
+DEFAULT_FORMAT = {"type": "text"}
+
+
 class JudgeLM:
     """
     A class representing a judge for evaluating answers using either the OpenAI or Transformers library.
@@ -76,6 +81,7 @@ def __init__(
         judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm"],
         url: str | None = None,
         api_key: str | None = None,
+        response_format: BaseModel | None = None,
     ):
         self.model = model
         self.template = templates
@@ -91,6 +97,8 @@ def __init__(
         self.api_key = api_key
         self.backend = judge_backend
 
+        self.response_format = response_format if response_format is not None else DEFAULT_FORMAT
+
     def __lazy_load_client(self):
         match self.backend:
             # Whether we use openai or TGI models, we go through the openai API
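
With this change a caller can hand the judge a Pydantic model as `response_format` and get a validated, structured verdict back instead of free text. A minimal sketch of that usage, assuming a hypothetical `JudgeVerdict` schema and template function (neither is part of this diff, and constructor arguments not visible in the diff are elided):

```python
from pydantic import BaseModel


# Hypothetical schema, for illustration only: ask the judge for a score
# plus a short rationale as structured output.
class JudgeVerdict(BaseModel):
    score: int
    rationale: str


def my_template(question, answer, **kwargs):
    # Assumed shape: returns an OpenAI-style message list for the judge.
    return [{"role": "user", "content": f"Rate this answer.\nQ: {question}\nA: {answer}"}]


judge = JudgeLM(
    model="gpt-4o-mini",           # assumed model name
    templates=my_template,         # assumed prompt-building callable
    judge_backend="openai",
    response_format=JudgeVerdict,  # omit to fall back to DEFAULT_FORMAT
)
```
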
@@ -232,7 +240,7 @@ def __call_api(prompt):
 
     def __call_api_parallel(self, prompts):
         results = []
-        with ThreadPoolExecutor(100) as executor:
+        with ThreadPoolExecutor(10) as executor:
             for entry in tqdm(executor.map(self.__call_api, prompts), total=len(prompts)):
                 results.append(entry)
 
@@ -244,16 +252,34 @@ def __call_api_parallel(self, prompts):
     def __call_api(self, prompt):
         for _ in range(self.API_MAX_RETRY):
             try:
-                response = self.client.chat.completions.create(
+                # Structured output: parse the reply into the Pydantic response format
+                response = self.client.beta.chat.completions.parse(
                     model=self.model,
-                    messages=prompt,
-                    response_format={"type": "text"},
-                    max_tokens=512,
+                    messages=as_list(prompt),
+                    response_format=self.response_format,
+                    max_tokens=4096,
+                    temperature=0.0,
                     n=1,
                 )
-                text = response.choices[0].message.content
-                return text
+                answer = response.choices[0].message.parsed
+                return answer
+            except TypeError:
+                try:
+                    # Fallback for fine-tuned / non-structured models: plain completion
+                    response = self.client.chat.completions.create(
+                        model=self.model,
+                        messages=as_list(prompt),
+                        response_format=self.response_format,
+                        max_tokens=512,
+                        n=1,
+                    )
+                    text = response.choices[0].message.content
+                    return text
+                except Exception as e:
+                    logger.warning(f"{type(e), e}")
+                    time.sleep(self.API_RETRY_SLEEP)
             except Exception as e:
                 logger.warning(f"{type(e), e}")
                 time.sleep(self.API_RETRY_SLEEP)
+
         raise Exception("Failed to get response from the API")
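
For reference, the structured/plain split that the retry loop relies on can be exercised directly against the OpenAI SDK. A minimal sketch, assuming `OPENAI_API_KEY` is set in the environment and using a hypothetical `Verdict` schema (both are assumptions, not part of this diff):

```python
from openai import OpenAI
from pydantic import BaseModel


class Verdict(BaseModel):  # assumed schema, mirrors the response_format usage above
    score: int


client = OpenAI()  # reads OPENAI_API_KEY from the environment
messages = [{"role": "user", "content": "Rate 'Paris is in France' from 0 to 10."}]

try:
    # Structured path: parse() validates the reply against the Pydantic model.
    response = client.beta.chat.completions.parse(
        model="gpt-4o-mini",  # assumed model name
        messages=messages,
        response_format=Verdict,
    )
    print(response.choices[0].message.parsed)  # Verdict(score=...)
except TypeError:
    # Plain path: a dict response_format such as {"type": "text"} is not a
    # Pydantic model, so fall back to create() and read raw text.
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        response_format={"type": "text"},
    )
    print(response.choices[0].message.content)
```
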