@@ -25,6 +25,7 @@
 import logging
 import time
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
 from typing import Callable, Literal, Optional
 
 from huggingface_hub import AsyncInferenceClient, InferenceTimeoutError
@@ -45,28 +46,40 @@
 DEFAULT_FORMAT = {"type": "text"}
 
 
-class JudgeLM:
-    """A class representing a judge for evaluating answers using either the OpenAI or Transformers library.
+@dataclass
+class LitellmBackendOptions:
+    """Options for the LiteLLM judge backend with default values.
 
-    Args:
-        model (str): The name of the model.
-        templates (Callable): A function taking into account the question, options, answer, and gold and returning the judge prompt.
-        process_judge_response (Callable): A function for processing the judge's response.
-        judge_backend (Literal["openai", "transformers", "tgi", "vllm"]): The backend for the judge.
-        url (str | None): The URL for the OpenAI API.
-        api_key (str | None): The API key for the OpenAI API (either OpenAI or HF key).
+    Attributes:
+        caching (bool): Whether to enable caching for the API responses. Defaults to True.
+        concurrent_requests (int): The maximum number of concurrent requests to the API. Defaults to 10.
+        increase_max_tokens_for_reasoning (bool): Whether to increase the max tokens for certain reasoning
+            models. Defaults to True.
+    """
+
+    caching: bool = True
+    concurrent_requests: int = 10
+
+    # Increases max_tokens depending on the model used, see implementation below
+    increase_max_tokens_for_reasoning: bool = True
+
+
+class JudgeLM:
68+ """A class representing a judge for evaluating answers using either the chosen backend.
 
     Attributes:
         model (str): The name of the model.
-        template (Callable): A function taking into account the question, options, answer, and gold and returning the judge prompt.
-        API_MAX_RETRY (int): The maximum number of retries for the API.
-        API_RETRY_SLEEP (int): The time to sleep between retries.
-        client (OpenAI | None): The OpenAI client.
-        pipe (LLM | AutoModel | None): The Transformers or vllm pipeline.
+        templates (Callable): A function taking into account the question, options, answer, and gold and returning the judge prompt.
         process_judge_response (Callable): A function for processing the judge's response.
+        judge_backend (Literal["litellm", "openai", "transformers", "tgi", "vllm", "inference-providers"]): The backend for the judge.
         url (str | None): The URL for the OpenAI API.
         api_key (str | None): The API key for the OpenAI API (either OpenAI or HF key).
-        backend (Literal["openai", "transformers", "tgi", "vllm"]): The backend for the judge
+        max_tokens (int): The maximum number of tokens to generate. Defaults to 512.
+        response_format (BaseModel | None): The format of the response from the API, used for the OpenAI and TGI backend.
+        hf_provider (Literal["black-forest-labs", "cerebras", "cohere", "fal-ai", "fireworks-ai",
+            "inference-providers", "hyperbolic", "nebius", "novita", "openai", "replicate", "sambanova", "together"] | None):
+            The HuggingFace provider when using the inference-providers backend.
+        backend_options (dict | None): Options for the backend. Currently only supported for litellm.
 
     Methods:
         evaluate_answer: Evaluates an answer using the OpenAI API or Transformers library.
@@ -103,6 +116,7 @@ def __init__(
                 "together",
             ]
         ] = None,
+        backend_options: dict | None = None,
     ):
         self.model = model
         self.template = templates
@@ -122,6 +136,12 @@ def __init__(
 
         self.response_format = response_format if not None else DEFAULT_FORMAT
 
+        self.backend_options = backend_options or {}
+
+        # Override backend options dictionary with the corresponding dataclass to ensure all specified options are valid
+        if judge_backend == "litellm":
+            self.backend_options = LitellmBackendOptions(**self.backend_options)
+
         # Validate that hf_provider is specified when using inference-providers backend
         if self.backend == "inference-providers" and self.hf_provider is None:
             raise ValueError("When using 'inference-providers' as backend, you must specify an 'hf_provider'")
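
Since `LitellmBackendOptions` is a plain dataclass, unpacking the user-supplied dict into it doubles as validation: any unrecognized key fails fast with a `TypeError` instead of being silently ignored. A minimal standalone sketch of that behavior:

```python
from dataclasses import dataclass


@dataclass
class LitellmBackendOptions:
    caching: bool = True
    concurrent_requests: int = 10
    increase_max_tokens_for_reasoning: bool = True


# Known keys override the defaults
opts = LitellmBackendOptions(**{"concurrent_requests": 4})
assert opts.caching is True and opts.concurrent_requests == 4

# A typo'd key raises immediately at construction time
try:
    LitellmBackendOptions(**{"concurent_requests": 4})
except TypeError as err:
    print(err)  # ... unexpected keyword argument 'concurent_requests'
```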
@@ -286,12 +306,22 @@ def __call_vllm(self, prompt):
     def __call_litellm(self, prompts):
         import litellm
 
+        if self.backend_options.caching:
+            from litellm.caching.caching import Cache, LiteLLMCacheType
+
+            litellm.cache = Cache(type=LiteLLMCacheType.DISK)
+
+        # Automatically drop parameters that are not supported by the currently used inference API
+        litellm.drop_params = True
+
         def __call_api(prompt):
             error_message = "ERROR: Failed to get response from the API."
             for _ in range(self.API_MAX_RETRY):
                 try:
-                    max_new_tokens = 512
-                    if "o1" in self.model or "o3" in self.model or "R1" in self.model:
+                    max_new_tokens = self.max_tokens
+
+                    is_reasoning_model = "o1" in self.model or "o3" in self.model or "R1" in self.model
+                    if is_reasoning_model and self.backend_options.increase_max_tokens_for_reasoning:
                         max_new_tokens = min(max_new_tokens * 10, 32000)
 
                     kwargs = {
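
For scale: with the default `max_tokens` of 512, the reasoning-model branch above yields min(512 * 10, 32000) = 5120 tokens; the 32000 cap only binds once `max_tokens` exceeds 3200. A quick illustration:

```python
def scaled_max_tokens(max_tokens: int) -> int:
    # Mirrors the reasoning-model branch above ("o1"/"o3"/"R1")
    return min(max_tokens * 10, 32000)


print(scaled_max_tokens(512))   # 5120
print(scaled_max_tokens(4096))  # 32000 (capped)
```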
@@ -319,7 +349,7 @@ def __call_api(prompt):
             return error_message
 
         results = []
-        with ThreadPoolExecutor(100) as executor:
+        with ThreadPoolExecutor(self.backend_options.concurrent_requests) as executor:
             for entry in tqdm(executor.map(__call_api, prompts), total=len(prompts)):
                 results.append(entry)
 
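
For reference, a hedged sketch of how the new parameter is meant to be used; `my_template` and `my_parser` (and the model name) are hypothetical placeholders, not part of this PR:

```python
# Hypothetical callables standing in for real prompt/response handlers
judge = JudgeLM(
    model="gpt-4o-mini",               # placeholder model name
    templates=my_template,             # builds the judge prompt
    process_judge_response=my_parser,  # extracts a verdict from the reply
    judge_backend="litellm",
    backend_options={
        "caching": False,              # skip LiteLLM's on-disk cache
        "concurrent_requests": 4,      # ThreadPoolExecutor pool size
    },
)
```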