Commit e67ed9c

Added custom model inference. (huggingface#437)
Enables the evaluation of any system in the user's control. Fixes [Issue 430](huggingface#430).

Try with:

```
python -m lighteval custom google-translate /path/to/google_translate_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10
```

google_translate_model.py:

```python
import logging
from typing import Optional

from tqdm import tqdm
from transformers import AutoTokenizer

from lighteval.data import GenerativeTaskDataset
from lighteval.models.abstract_model import LightevalModel, ModelInfo
from lighteval.models.model_output import (
    GenerativeResponse,
    LoglikelihoodResponse,
    LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
    GreedyUntilRequest,
    LoglikelihoodRequest,
    LoglikelihoodRollingRequest,
    LoglikelihoodSingleTokenRequest,
)


logger = logging.getLogger(__name__)


class GoogleTranslateClient(LightevalModel):
    def __init__(self, config, env_config) -> None:
        self.model = config.model
        self.model_definition_file_path = config.model_definition_file_path

        self.model_info = ModelInfo(
            model_name=config.model,
            model_sha="",
            model_dtype=None,
            model_size="",
        )

        self._tokenizer = AutoTokenizer.from_pretrained("gpt2")  # Use a dummy tokenizer for compatibility

        import httpcore

        # Needed to fix some googletrans bug
        # https://stackoverflow.com/questions/72796594/attributeerror-module-httpcore-has-no-attribute-synchttptransport#comment136664963_77334618
        setattr(httpcore, "SyncHTTPTransport", "AsyncHTTPProxy")
        from googletrans import Translator

        self.translator = Translator()

    def greedy_until(
        self,
        requests: list[GreedyUntilRequest],
        override_bs: Optional[int] = None,
    ) -> list[GenerativeResponse]:
        """
        Generates responses using a greedy decoding strategy until certain ending conditions are met.

        Args:
            requests (list[Request]): list of requests containing the context and ending conditions.
            disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
            override_bs (int, optional): Override the batch size for generation. Defaults to None.

        Returns:
            list[GenerativeResponse]: list of generated responses.
        """
        for request in requests:
            request.tokenized_context = self.tok_encode(request.context)

        dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
        results = []

        for _ in tqdm(
            dataset.splits_start_end_iterator(),
            total=dataset.num_dataset_splits,
            desc="Splits",
            position=0,
            disable=False,  # self.disable_tqdm,
        ):
            for r in tqdm(dataset, desc="Batch", position=1, disable=False):
                context = r.context.replace("French phrase: ", "")

                # TODO: Get src and dest from request
                translation = self.translator.translate(context, src="fr", dest="de")
                result = translation.text

                cur_response = GenerativeResponse(
                    result=result,
                    logits=None,
                    generated_tokens=[],
                    input_tokens=[],
                )
                results.append(cur_response)

        return dataset.get_original_order(results)

    @property
    def tokenizer(self):
        return self._tokenizer

    def tok_encode(self, text: str):
        return self.tokenizer.encode(text)

    @property
    def add_special_tokens(self) -> bool:
        return False

    @property
    def max_length(self) -> int:
        """Return the maximum sequence length of the model."""
        return 4096

    def loglikelihood(
        self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        raise NotImplementedError

    def loglikelihood_rolling(
        self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """This function is used to compute the log likelihood of the context for perplexity metrics."""
        raise NotImplementedError

    def loglikelihood_single_token(
        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodSingleTokenResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        raise NotImplementedError
```
1 parent dfeb234 commit e67ed9c

10 files changed: +869 -0 lines changed

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions

```diff
@@ -15,6 +15,8 @@
     title: Add a custom task
   - local: adding-a-new-metric
     title: Add a custom metric
+  - local: evaluating-a-custom-model
+    title: Evaluate a custom model
   - local: use-inference-providers-as-backend
     title: Use HF's inference providers as backend
   - local: use-litellm-as-backend
```
docs/source/evaluating-a-custom-model.mdx

Lines changed: 127 additions & 0 deletions

# Evaluating a Custom Model

Lighteval allows you to evaluate custom model implementations by creating a custom model class that inherits from `LightevalModel`. This is useful when you want to evaluate models that aren't directly supported by the standard backends (transformers, vllm, etc.).

## Creating a Custom Model

1. Create a Python file containing your custom model implementation. The model must inherit from `LightevalModel` and implement all required methods.

Here's a basic example:

```python
from lighteval.models.abstract_model import LightevalModel


class MyCustomModel(LightevalModel):
    def __init__(self, config):
        super().__init__(config)
        # Initialize your model here...

    def greedy_until(self, requests, max_tokens=None, stop_sequences=None):
        # Implement generation logic
        pass

    def loglikelihood(self, requests, log=True):
        # Implement loglikelihood computation
        pass

    def loglikelihood_rolling(self, requests):
        # Implement rolling loglikelihood computation
        pass

    def loglikelihood_single_token(self, requests):
        # Implement single token loglikelihood computation
        pass
```

2. The custom model file should contain exactly one class that inherits from `LightevalModel`. This class will be automatically detected and instantiated when loading the model.
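To make the loading step concrete, automatic detection boils down to importing the user file and picking out the one `LightevalModel` subclass it defines. The helper below is a hypothetical sketch of that idea, not lighteval's actual loader code:

```python
# Hypothetical sketch of locating the single LightevalModel subclass in a user file.
# lighteval's real loader may differ in naming and details.
import importlib.util
import inspect

from lighteval.models.abstract_model import LightevalModel


def find_custom_model_class(path: str) -> type[LightevalModel]:
    spec = importlib.util.spec_from_file_location("custom_model_module", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # Run the user file so its classes are defined

    candidates = [
        obj
        for _, obj in inspect.getmembers(module, inspect.isclass)
        if issubclass(obj, LightevalModel) and obj is not LightevalModel
    ]
    if len(candidates) != 1:
        raise ValueError(f"Expected exactly one LightevalModel subclass in {path}, found {len(candidates)}")
    return candidates[0]
```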
> [!TIP]
> You can find a complete example of a custom model implementation in `examples/custom_models/google_translate_model.py`.

## Running the Evaluation

You can evaluate your custom model using either the command line interface or the Python API.

### Using the Command Line

```bash
lighteval custom \
    "google-translate" \
    "examples/custom_models/google_translate_model.py" \
    "lighteval|wmt20:fr-de|0|0" \
    --max-samples 10
```

The command takes three required arguments:
- The model name (used for tracking in results/logs)
- The path to your model implementation file
- The tasks to evaluate on (same format as other backends)
### Using the Python API

```python
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.custom.custom_model import CustomModelConfig
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters

# Set up evaluation tracking
evaluation_tracker = EvaluationTracker(
    output_dir="results",
    save_details=True
)

# Configure the pipeline
pipeline_params = PipelineParameters(
    launcher_type=ParallelismManager.CUSTOM,
)

# Configure your custom model
model_config = CustomModelConfig(
    model="my-custom-model",
    model_definition_file_path="path/to/my_model.py"
)

# Create and run the pipeline
pipeline = Pipeline(
    tasks="leaderboard|truthfulqa:mc|0|0",
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    model_config=model_config
)

pipeline.evaluate()
pipeline.save_and_push_results()
```
## Required Methods

Your custom model must implement these core methods:

- `greedy_until`: For generating text until a stop sequence or max tokens is reached
- `loglikelihood`: For computing log probabilities of specific continuations
- `loglikelihood_rolling`: For computing rolling log probabilities of sequences
- `loglikelihood_single_token`: For computing log probabilities of single tokens

See the `LightevalModel` base class documentation for detailed method signatures and requirements.
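For orientation, the typed signatures used by the Google Translate example added in this PR look roughly like the skeleton below; treat it as a sketch and defer to the `LightevalModel` base class for the authoritative definitions.

```python
from lighteval.models.abstract_model import LightevalModel
from lighteval.models.model_output import (
    GenerativeResponse,
    LoglikelihoodResponse,
    LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
    GreedyUntilRequest,
    LoglikelihoodRequest,
    LoglikelihoodRollingRequest,
    LoglikelihoodSingleTokenRequest,
)


class SketchModel(LightevalModel):
    # Signatures mirror examples/custom_models/google_translate_model.py from this PR.
    # Tokenizer/max_length properties and other base-class requirements are omitted here.
    def greedy_until(self, requests: list[GreedyUntilRequest]) -> list[GenerativeResponse]:
        raise NotImplementedError

    def loglikelihood(self, requests: list[LoglikelihoodRequest]) -> list[LoglikelihoodResponse]:
        raise NotImplementedError

    def loglikelihood_rolling(self, requests: list[LoglikelihoodRollingRequest]) -> list[LoglikelihoodResponse]:
        raise NotImplementedError

    def loglikelihood_single_token(
        self, requests: list[LoglikelihoodSingleTokenRequest]
    ) -> list[LoglikelihoodSingleTokenResponse]:
        raise NotImplementedError
```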
## Best Practices

1. **Error Handling**: Implement robust error handling in your model methods to gracefully handle edge cases.

2. **Batching**: Consider implementing efficient batching in your model methods to improve performance.

3. **Resource Management**: Properly manage any resources (e.g., API connections, model weights) in your model's `__init__` and `__del__` methods.

4. **Documentation**: Add clear docstrings to your model class and methods explaining any specific requirements or limitations.
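As an illustration of points 1 and 3, the Google Translate example in this PR caches results on disk and retries transient failures. Below is a trimmed-down sketch of that pattern; `call_backend` is a hypothetical stand-in for your real API or model call:

```python
import diskcache
import tenacity

cache = diskcache.Cache(".api_cache")  # persisted across runs


def call_backend(prompt: str) -> str:
    """Hypothetical placeholder for the real API/model call in your custom model."""
    return prompt.upper()


@tenacity.retry(
    stop=tenacity.stop_after_attempt(3),
    wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
)
def query_with_cache(prompt: str) -> str:
    # Serve cached responses first to avoid repeated (and possibly rate-limited) calls.
    if prompt in cache:
        return cache[prompt]
    result = call_backend(prompt)
    cache[prompt] = result
    return result
```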
## Example Use Cases

Custom models are particularly useful for:

- Evaluating models accessed through custom APIs
- Wrapping models with specialized preprocessing/postprocessing
- Testing novel model architectures
- Evaluating ensemble models
- Integrating with external services or tools

For a complete example of a custom model that wraps the Google Translate API, see `examples/custom_models/google_translate_model.py`.

docs/source/package_reference/models.mdx

Lines changed: 3 additions & 0 deletions

```diff
@@ -28,6 +28,9 @@
 [[autodoc]] models.endpoints.tgi_model.TGIModelConfig
 [[autodoc]] models.endpoints.tgi_model.ModelClient
 
+### Custom Model
+[[autodoc]] models.custom.custom_model.CustomModelConfig
+
 ### Open AI Models
 [[autodoc]] models.endpoints.openai_model.OpenAIClient
```

examples/custom_models/google_translate_model.py

Lines changed: 198 additions & 0 deletions

```python
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import hashlib
import logging
import os
import time

import diskcache
import tenacity
from deep_translator import GoogleTranslator
from tqdm import tqdm
from transformers import AutoTokenizer

from lighteval.data import GenerativeTaskDataset
from lighteval.models.abstract_model import LightevalModel, ModelInfo
from lighteval.models.model_output import (
    GenerativeResponse,
    LoglikelihoodResponse,
    LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
    GreedyUntilRequest,
    LoglikelihoodRequest,
    LoglikelihoodRollingRequest,
    LoglikelihoodSingleTokenRequest,
)


logger = logging.getLogger(__name__)


class GoogleTranslateClient(LightevalModel):
    def __init__(self, config) -> None:
        self.model = config.model_name
        self.model_definition_file_path = config.model_definition_file_path

        self.model_info = ModelInfo(
            model_name=config.model,
            model_sha="",
            model_dtype=None,
            model_size="",
        )

        self._tokenizer = AutoTokenizer.from_pretrained("gpt2")  # Use a dummy tokenizer for compatibility

        # Deep-translator also supports other translators
        self.translator = GoogleTranslator()

        # Initialize disk cache
        cache_dir = os.path.join(os.getcwd(), ".translation_cache")
        self.cache = diskcache.Cache(cache_dir)

        self.max_retries = 3
        self.retry_delay = 1

    def _get_cache_key(self, context: str, src_lang: str, tgt_lang: str) -> str:
        """Generate a unique cache key for the translation request."""
        # IMPORTANT: In case we want to support other translators, we can add the translator name to the key
        key_string = f"{context}|{src_lang}|{tgt_lang}"
        return hashlib.md5(key_string.encode()).hexdigest()

    @tenacity.retry(
        stop=tenacity.stop_after_attempt(3),
        wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
        retry=tenacity.retry_if_exception_type((Exception)),
        before_sleep=lambda retry_state: time.sleep(1),
    )
    def _translate_with_cache(self, context: str, src_lang: str, tgt_lang: str) -> str:
        """Translate text using cache if available, otherwise call Google Translate with retry logic."""
        cache_key = self._get_cache_key(context, src_lang, tgt_lang)

        # Try to get from cache
        if cache_key in self.cache:
            result = self.cache[cache_key]
            if result is not None and result != "":
                return result
            logger.warning("Translation in cache is empty. Removing from cache and retrying...")
            del self.cache[cache_key]

        try:
            # Updated translation call for deep-translator
            self.translator.source = src_lang
            self.translator.target = tgt_lang
            result = self.translator.translate(context)
            if result is None or result == "":
                result = ""

            self.cache[cache_key] = result
            return result
        except Exception as e:
            logger.warning(f"Translation error: {str(e)}. Retrying...")
            raise  # Let tenacity handle the retry

    def greedy_until(
        self,
        requests: list[GreedyUntilRequest],
    ) -> list[GenerativeResponse]:
        """
        Generates responses using a greedy decoding strategy until certain ending conditions are met.
        Results are cached to disk to avoid repeated translations.

        Args:
            requests (list[Request]): list of requests containing the context and ending conditions.
            override_bs (int, optional): Override the batch size for generation. Defaults to None.

        Returns:
            list[GenerativeResponse]: list of generated responses.
        """
        for request in requests:
            request.tokenized_context = self.tok_encode(request.context)

        dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
        results = []

        for _ in tqdm(
            dataset.splits_start_end_iterator(),
            total=dataset.num_dataset_splits,
            desc="Splits",
            position=0,
            disable=False,  # self.disable_tqdm,
        ):
            for r in tqdm(dataset, desc="Batch", position=1, disable=False):
                # Extract source and target languages from task name
                # Format is like "community|sdst-text_level:de-fr|0"
                src_lang, tgt_lang = r.task_name.split("|")[1].split(":")[-1].split("-")

                context = r.context.replace(f"{src_lang.upper()}: ", "").replace(f"\n{tgt_lang.upper()}: ", "")
                result = self._translate_with_cache(context, src_lang, tgt_lang)
                if result is None:
                    result = ""  # Set to empty string to prevent errors in metric computation

                cur_response = GenerativeResponse(
                    result=result,
                    logits=None,
                    generated_tokens=[],
                    input_tokens=[],
                )
                results.append(cur_response)

        return dataset.get_original_order(results)

    @property
    def tokenizer(self):
        return self._tokenizer

    def tok_encode(self, text: str):
        return text

    @property
    def add_special_tokens(self) -> bool:
        return False

    @property
    def max_length(self) -> int:
        """Return the maximum sequence length of the model."""
        return 4096

    def loglikelihood(self, requests: list[LoglikelihoodRequest]) -> list[LoglikelihoodResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        raise NotImplementedError

    def loglikelihood_rolling(
        self,
        requests: list[LoglikelihoodRollingRequest],
    ) -> list[LoglikelihoodResponse]:
        """This function is used to compute the log likelihood of the context for perplexity metrics."""
        raise NotImplementedError

    def loglikelihood_single_token(
        self,
        requests: list[LoglikelihoodSingleTokenRequest],
    ) -> list[LoglikelihoodSingleTokenResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        raise NotImplementedError
```
