Commit 9ba430f

rolshoven and NathanHB authored
Added backend_options parameter to llm judges. (#963)
* Added `backend_options` parameter to llm judges. Currently only used for the litellm backend, but can be extended to other backends as well. Allows specifying whether to use caching, the number of concurrent requests, and whether the token output budget should be increased for reasoning models.
* Implemented changes from code review
* Ran pre-commit hooks

---------

Co-authored-by: Nathan Habib <[email protected]>
1 parent 2dc1788 commit 9ba430f
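
The new parameter is a plain dictionary that is validated against the `LitellmBackendOptions` dataclass added in `llm_as_judge.py` (see the diff below). A minimal sketch of what a caller might pass; the keys come from that dataclass, the values are purely illustrative:

```python
# Sketch of the dictionary accepted by the new `backend_options` parameter when
# the judge runs on the litellm backend. Keys and defaults come from the
# LitellmBackendOptions dataclass added in this commit; the values are examples.
backend_options = {
    "caching": False,                           # skip litellm's on-disk response cache
    "concurrent_requests": 20,                  # cap the number of parallel judge requests
    "increase_max_tokens_for_reasoning": True,  # 10x the token budget (capped at 32k) for o1/o3/R1-style judges
}
```

Omitted keys fall back to the defaults documented in the `llm_as_judge.py` diff below (`caching=True`, `concurrent_requests=10`, `increase_max_tokens_for_reasoning=True`).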

File tree

2 files changed: +50 -18

src/lighteval/metrics/metrics_sample.py (+2 -0)
src/lighteval/metrics/utils/llm_as_judge.py (+48 -18)


src/lighteval/metrics/metrics_sample.py

Lines changed: 2 additions & 0 deletions
@@ -950,6 +950,7 @@ def __init__(
         url: str | None = None,
         hf_provider: str | None = None,
         max_tokens: int | None = None,
+        backend_options: dict | None = None,
     ) -> None:
         logger.debug(f"Initializing JudgeLLM with backend: {judge_backend}, model: {judge_model_name}")

@@ -996,6 +997,7 @@ def __init__(
             url=url,
             hf_provider=hf_provider,
             max_tokens=max_tokens,
+            backend_options=backend_options,
         )

     def compute(self, responses: list[ModelResponse], docs: list[Doc], **kwargs) -> list:

src/lighteval/metrics/utils/llm_as_judge.py

Lines changed: 48 additions & 18 deletions
@@ -25,6 +25,7 @@
 import logging
 import time
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
 from typing import Callable, Literal, Optional

 from huggingface_hub import AsyncInferenceClient, InferenceTimeoutError
@@ -45,28 +46,40 @@
 DEFAULT_FORMAT = {"type": "text"}


-class JudgeLM:
-    """A class representing a judge for evaluating answers using either the OpenAI or Transformers library.
+@dataclass
+class LitellmBackendOptions:
+    """Options for the LiteLLM judge backend with default values.

-    Args:
-        model (str): The name of the model.
-        templates (Callable): A function taking into account the question, options, answer, and gold and returning the judge prompt.
-        process_judge_response (Callable): A function for processing the judge's response.
-        judge_backend (Literal["openai", "transformers", "tgi", "vllm"]): The backend for the judge.
-        url (str | None): The URL for the OpenAI API.
-        api_key (str | None): The API key for the OpenAI API (either OpenAI or HF key).
+    Attributes:
+        caching (bool): Whether to enable caching for the API responses. Defaults to True.
+        concurrent_requests (int): The maximum number of concurrent requests to the API. Defaults to 10.
+        increase_max_tokens_for_reasoning (bool): Whether to increase the max tokens for certain reasoning
+            models. Defaults to True.
+    """
+
+    caching: bool = True
+    concurrent_requests: int = 10
+
+    # Increases max_tokens depending on the model used, see implementation below
+    increase_max_tokens_for_reasoning: bool = True
+
+
+class JudgeLM:
+    """A class representing a judge for evaluating answers using the chosen backend.

     Attributes:
         model (str): The name of the model.
-        template (Callable): A function taking into account the question, options, answer, and gold and returning the judge prompt.
-        API_MAX_RETRY (int): The maximum number of retries for the API.
-        API_RETRY_SLEEP (int): The time to sleep between retries.
-        client (OpenAI | None): The OpenAI client.
-        pipe (LLM | AutoModel | None): The Transformers or vllm pipeline.
+        templates (Callable): A function taking into account the question, options, answer, and gold and returning the judge prompt.
         process_judge_response (Callable): A function for processing the judge's response.
+        judge_backend (Literal["litellm", "openai", "transformers", "tgi", "vllm", "inference-providers"]): The backend for the judge.
         url (str | None): The URL for the OpenAI API.
         api_key (str | None): The API key for the OpenAI API (either OpenAI or HF key).
-        backend (Literal["openai", "transformers", "tgi", "vllm"]): The backend for the judge
+        max_tokens (int): The maximum number of tokens to generate. Defaults to 512.
+        response_format (BaseModel | None): The format of the response from the API, used for the OpenAI and TGI backend.
+        hf_provider (Literal["black-forest-labs", "cerebras", "cohere", "fal-ai", "fireworks-ai",
+            "inference-providers", "hyperbolic", "nebius", "novita", "openai", "replicate", "sambanova", "together"] | None):
+            The HuggingFace provider when using the inference-providers backend.
+        backend_options (dict | None): Options for the backend. Currently only supported for litellm.

     Methods:
         evaluate_answer: Evaluates an answer using the OpenAI API or Transformers library.
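
The dataclass introduced above is small enough to exercise on its own. A self-contained sketch (the fields are copied from the hunk above) showing that omitted options keep their documented defaults:

```python
from dataclasses import dataclass


@dataclass
class LitellmBackendOptions:
    """Fields copied from the hunk above so the example runs without lighteval installed."""

    caching: bool = True
    concurrent_requests: int = 10
    increase_max_tokens_for_reasoning: bool = True


print(LitellmBackendOptions())
# LitellmBackendOptions(caching=True, concurrent_requests=10, increase_max_tokens_for_reasoning=True)
print(LitellmBackendOptions(concurrent_requests=4))
# LitellmBackendOptions(caching=True, concurrent_requests=4, increase_max_tokens_for_reasoning=True)
```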
@@ -103,6 +116,7 @@ def __init__(
                 "together",
             ]
         ] = None,
+        backend_options: dict | None = None,
     ):
         self.model = model
         self.template = templates

@@ -122,6 +136,12 @@ def __init__(

         self.response_format = response_format if not None else DEFAULT_FORMAT

+        self.backend_options = backend_options or {}
+
+        # Override backend options dictionary with the corresponding dataclass to ensure all specified options are valid
+        if judge_backend == "litellm":
+            self.backend_options = LitellmBackendOptions(**self.backend_options)
+
         # Validate that hf_provider is specified when using inference-providers backend
         if self.backend == "inference-providers" and self.hf_provider is None:
             raise ValueError("When using 'inference-providers' as backend, you must specify an 'hf_provider'")
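
Because the dictionary is expanded with `**` into the dataclass, a misspelled option name fails immediately instead of being silently ignored. A short sketch of that behaviour, assuming a lighteval build that includes this patch:

```python
# LitellmBackendOptions lives in the module shown in this diff; the import path
# is inferred from the file path src/lighteval/metrics/utils/llm_as_judge.py.
from lighteval.metrics.utils.llm_as_judge import LitellmBackendOptions

LitellmBackendOptions(**{"concurrent_requests": 4})  # recognised key: accepted

try:
    LitellmBackendOptions(**{"cache": False})  # typo for "caching": rejected
except TypeError as err:
    print(f"rejected: {err}")
```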
@@ -286,12 +306,22 @@ def __call_vllm(self, prompt):
     def __call_litellm(self, prompts):
         import litellm

+        if self.backend_options.caching:
+            from litellm.caching.caching import Cache, LiteLLMCacheType
+
+            litellm.cache = Cache(type=LiteLLMCacheType.DISK)
+
+        # Automatically drop parameters that are not supported by the currently used inference API
+        litellm.drop_params = True
+
         def __call_api(prompt):
             error_message = "ERROR: Failed to get response from the API."
             for _ in range(self.API_MAX_RETRY):
                 try:
-                    max_new_tokens = 512
-                    if "o1" in self.model or "o3" in self.model or "R1" in self.model:
+                    max_new_tokens = self.max_tokens
+
+                    is_reasoning_model = "o1" in self.model or "o3" in self.model or "R1" in self.model
+                    if is_reasoning_model and self.backend_options.increase_max_tokens_for_reasoning:
                         max_new_tokens = min(max_new_tokens * 10, 32000)

                     kwargs = {
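
The hunk above keeps the original 10x scaling rule but makes it conditional on the new option and on the configured `max_tokens`. A standalone sketch of the arithmetic (the `budget_for` helper is illustrative, not part of the patch; the model strings are only the substrings the code checks for):

```python
def budget_for(model: str, max_tokens: int = 512, increase_for_reasoning: bool = True) -> int:
    """Mirror of the max-token logic in __call_litellm: reasoning judges get 10x, capped at 32k."""
    is_reasoning_model = "o1" in model or "o3" in model or "R1" in model
    if is_reasoning_model and increase_for_reasoning:
        return min(max_tokens * 10, 32000)
    return max_tokens


print(budget_for("gpt-4o-mini"))                             # 512   (not a reasoning model)
print(budget_for("o3-mini"))                                 # 5120  (512 * 10)
print(budget_for("o3-mini", max_tokens=4096))                # 32000 (capped)
print(budget_for("o3-mini", increase_for_reasoning=False))   # 512   (option disabled)
```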
@@ -319,7 +349,7 @@ def __call_api(prompt):
             return error_message

         results = []
-        with ThreadPoolExecutor(100) as executor:
+        with ThreadPoolExecutor(self.backend_options.concurrent_requests) as executor:
             for entry in tqdm(executor.map(__call_api, prompts), total=len(prompts)):
                 results.append(entry)
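Previously the litellm path always created a `ThreadPoolExecutor` with 100 workers; the pool size is now taken from `concurrent_requests` (default 10). A self-contained sketch of the same fan-out pattern with a configurable worker count; `fake_judge_call` stands in for the real API call:

```python
from concurrent.futures import ThreadPoolExecutor

concurrent_requests = 10  # default from LitellmBackendOptions


def fake_judge_call(prompt: str) -> str:
    """Stand-in for __call_api, which retries and calls the litellm completion API."""
    return f"score for: {prompt}"


prompts = [f"prompt {i}" for i in range(25)]

results = []
# At most `concurrent_requests` prompts are judged in parallel; map() preserves input order.
with ThreadPoolExecutor(concurrent_requests) as executor:
    for entry in executor.map(fake_judge_call, prompts):
        results.append(entry)

print(len(results))  # 25
```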