
Commit df3a82d

Caching samples PR (#909)
Adds a new caching system for generative evals, plus a test suite and a documentation page. The system first loads the cache indices, then runs the model only on the samples that are missing from the cache, and lastly loads the cached items as needed (the cache is not kept in memory while the model is running).
1 parent bfa6076 commit df3a82d

23 files changed, +1066 -552 lines

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,8 @@
 - sections:
   - local: saving-and-reading-results
     title: Save and read results
+  - local: caching
+    title: Caching
   - local: using-the-python-api
     title: Use the Python API
   - local: adding-a-custom-task

docs/source/caching.mdx

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+# Caching System
+
+Lighteval includes a caching system that can significantly speed up evaluations by storing and reusing model predictions.
+This is especially useful when running the same evaluation multiple times, or when comparing different evaluation metrics on the same model outputs.
+
+## How It Works
+
+For now, the caching system only caches model predictions (tokenized input caching will be added later).
+It stores model response objects (generations, logits, probabilities) for evaluation samples.
+
+### Cache Structure
+
+Cached data is stored on disk using HuggingFace datasets in the following structure:
+
+```
+./cache/
+└── huggingface/
+    └── lighteval/
+        └── predictions/
+            └── {model_name}/
+                └── {model_hash}/
+                    └── {task_name}.parquet
+```
+
+Where:
+- `model_name`: The model name (path on the Hub or local path)
+- `model_hash`: Hash of the model configuration, to ensure cache invalidation when parameters change
+- `task_name`: Name of the evaluation task
+
+### Cache Recreation
+
+A new cache is automatically created when:
+- Model configuration changes (different parameters, quantization, etc.)
+- Model weights change (different revision, checkpoint, etc.)
+- Generation parameters change (temperature, max_tokens, etc.)
+
+This ensures that cached results are always consistent with your current model setup.
+
+## Using Caching
+
+### Automatic Caching
+
+All built-in model classes in Lighteval automatically support caching. No additional configuration is needed.
+For custom models, you need to add a cache to the model class and apply the caching decorator to its prediction methods.
+
+## Cache Management
+
+### Clearing Cache
+
+To clear the cache for a specific model, delete the corresponding directory:
+
+```bash
+rm -rf ./cache/huggingface/lighteval/predictions/{model_name}/{model_hash}/
+```
+
+To clear all caches:
+
+```bash
+rm -rf ./cache/huggingface/lighteval/predictions
+```
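
Since each cached task is stored as a parquet file written with HuggingFace datasets, the cache can be inspected directly from Python. The snippet below is only an illustrative sketch, not part of this commit; the cache root matches the default documented above, and the files it finds depend entirely on your local setup:

```python
# Illustrative sketch only: list and inspect cached predictions on disk.
# The directory layout follows the structure documented above.
from pathlib import Path

from datasets import load_dataset

cache_root = Path("./cache/huggingface/lighteval/predictions")

# Walk the cache and print every cached task file per model/hash.
for parquet_file in sorted(cache_root.rglob("*.parquet")):
    print(parquet_file.relative_to(cache_root))

# Load one cached task as a regular dataset to look at the stored responses.
example_file = next(cache_root.rglob("*.parquet"), None)
if example_file is not None:
    cached = load_dataset("parquet", data_files=str(example_file), split="train")
    print(cached.column_names)
```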

docs/source/evaluating-a-custom-model.mdx

Lines changed: 47 additions & 26 deletions
@@ -1,6 +1,8 @@
-# Evaluating a Custom Model
+# Custom Model
 
-Lighteval allows you to evaluate custom model implementations by creating a custom model class that inherits from `LightevalModel`. This is useful when you want to evaluate models that aren't directly supported by the standard backends (transformers, vllm, etc).
+Lighteval allows you to evaluate custom model implementations by creating a custom model class that inherits from `LightevalModel`.
+This is useful when you want to evaluate models that aren't directly supported by the standard backends and providers (transformers, vllm, etc), or
+if you want to add your own pre/post processing.
 
 ## Creating a Custom Model
 
@@ -9,28 +11,34 @@ Lighteval allows you to evaluate custom model implementations by creating a cust
 Here's a basic example:
 
 ```python
+from typing import List
 from lighteval.models.abstract_model import LightevalModel
+from lighteval.models.model_output import ModelResponse
+from lighteval.tasks.requests import Doc
+from lighteval.utils.cache_management import SampleCache, cached
 
 class MyCustomModel(LightevalModel):
     def __init__(self, config):
         super().__init__(config)
         # Initialize your model here...
 
-    def greedy_until(self, requests, max_tokens=None, stop_sequences=None):
+        # Enable caching (recommended)
+        self._cache = SampleCache(config)
+
+    @cached("predictions")  # Enable caching for better performance
+    def greedy_until(self, docs: List[Doc]) -> List[ModelResponse]:
         # Implement generation logic
         pass
 
-    def loglikelihood(self, requests, log=True):
+    @cached("predictions")  # Enable caching for better performance
+    def loglikelihood(self, docs: List[Doc]) -> List[ModelResponse]:
        # Implement loglikelihood computation
        pass
 
-    def loglikelihood_rolling(self, requests):
+    @cached("predictions")  # Enable caching for better performance
+    def loglikelihood_rolling(self, docs: List[Doc]) -> List[ModelResponse]:
        # Implement rolling loglikelihood computation
        pass
-
-    def loglikelihood_single_token(self, requests):
-        # Implement single token loglikelihood computation
-        pass
 ```
 
 2. The custom model file should contain exactly one class that inherits from `LightevalModel`. This class will be automatically detected and instantiated when loading the model.
@@ -97,31 +105,44 @@ pipeline.save_and_push_results()
 
 Your custom model must implement these core methods:
 
-- `greedy_until`: For generating text until a stop sequence or max tokens is reached
-- `loglikelihood`: For computing log probabilities of specific continuations
-- `loglikelihood_rolling`: For computing rolling log probabilities of sequences
-- `loglikelihood_single_token`: For computing log probabilities of single tokens
+- `greedy_until`: For generating text until a stop sequence or max tokens is reached - this is used for generative evaluations
+- `loglikelihood`: For computing log probabilities of specific continuations - this is used for multiple choice logprob evaluations
+- `loglikelihood_rolling`: For computing rolling log probabilities of sequences - this is used for perplexity metrics
 
 See the `LightevalModel` base class documentation for detailed method signatures and requirements.
 
-## Best Practices
+## Enabling Caching (Recommended)
 
-1. **Error Handling**: Implement robust error handling in your model methods to gracefully handle edge cases.
+Lighteval includes a caching system that can significantly speed up evaluations by storing and reusing model predictions.
+To enable caching in your custom model:
 
-2. **Batching**: Consider implementing efficient batching in your model methods to improve performance.
+1. **Import caching components**:
+   ```python
+   from lighteval.utils.cache_management import SampleCache, cached
+   ```
 
-3. **Resource Management**: Properly manage any resources (e.g., API connections, model weights) in your model's `__init__` and `__del__` methods.
+2. **Initialize cache in constructor**:
+   ```python
+   def __init__(self, config):
+       # Your initialization code...
+       self._cache = SampleCache(config)
+   ```
 
-4. **Documentation**: Add clear docstrings to your model class and methods explaining any specific requirements or limitations.
+3. **Add cache decorators** to your prediction methods:
+   ```python
+   @cached("predictions")
+   def greedy_until(self, docs: List[Doc]) -> List[ModelResponse]:
+       # Your implementation...
+   ```
 
-## Example Use Cases
+For detailed information about the caching system, see the [Caching Documentation](./caching.mdx).
 
-Custom models are particularly useful for:
+## Best Practices
+
+1. **Error Handling**: Implement robust error handling in your model methods to gracefully handle edge cases.
+
+2. **Batching**: Consider implementing efficient batching in your model methods to improve performance.
 
-- Evaluating models accessed through custom APIs
-- Wrapping models with specialized preprocessing/postprocessing
-- Testing novel model architectures
-- Evaluating ensemble models
-- Integrating with external services or tools
+3. **Documentation**: Add clear docstrings to your model class and methods explaining any specific requirements or limitations.
 
-For a complete example of a custom model that wraps the Google Translate API, see `examples/custom_models/google_translate_model.py`.
+4. **Caching**: Enable caching to speed up repeated evaluations and development iterations.
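
The commit description sums up what these decorators do: the cache index is consulted first, the model is only run on the samples that are missing, and the cached items are only loaded afterwards, so responses are never held in memory while the model runs. The sketch below is a simplified illustration of that flow, not the actual `lighteval.utils.cache_management` implementation; the `contains`/`store`/`load` helpers and the `doc.id` field are invented for the example.

```python
# Simplified, illustrative sketch of the caching flow described in the commit
# message. This is NOT the real lighteval implementation: the actual
# SampleCache and @cached decorator live in lighteval.utils.cache_management
# and persist responses as parquet datasets; the cache API used here
# (contains/store/load, doc.id) is invented for illustration.
from functools import wraps


def cached_predictions(method):
    """Wrap a prediction method so that only uncached docs hit the model."""

    @wraps(method)
    def wrapper(self, docs):
        cache = self._cache

        # 1. Check the cache index first to find which samples are missing.
        missing = [doc for doc in docs if not cache.contains(doc)]

        # 2. Run the model only on the missing samples and store the results.
        fresh = method(self, missing) if missing else []
        for doc, response in zip(missing, fresh):
            cache.store(doc, response)

        # 3. Only now load the cached items, so the cache is never kept in
        #    memory while the model is running.
        fresh_by_id = {doc.id: resp for doc, resp in zip(missing, fresh)}
        return [
            fresh_by_id[doc.id] if doc.id in fresh_by_id else cache.load(doc)
            for doc in docs
        ]

    return wrapper
```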

examples/custom_models/google_translate_model.py

Lines changed: 9 additions & 23 deletions
@@ -34,15 +34,10 @@
 from lighteval.data import GenerativeTaskDataset
 from lighteval.models.abstract_model import LightevalModel
 from lighteval.models.model_output import (
-    GenerativeResponse,
-    LoglikelihoodResponse,
-    LoglikelihoodSingleTokenResponse,
+    ModelResponse,
 )
 from lighteval.tasks.requests import (
-    GreedyUntilRequest,
-    LoglikelihoodRequest,
-    LoglikelihoodRollingRequest,
-    LoglikelihoodSingleTokenRequest,
+    Doc,
 )


@@ -107,8 +102,8 @@ def _translate_with_cache(self, context: str, src_lang: str, tgt_lang: str) -> s
 
     def greedy_until(
         self,
-        requests: list[GreedyUntilRequest],
-    ) -> list[GenerativeResponse]:
+        requests: list[Doc],
+    ) -> list[ModelResponse]:
         """
         Generates responses using a greedy decoding strategy until certain ending conditions are met.
         Results are cached to disk to avoid repeated translations.
@@ -118,7 +113,7 @@ def greedy_until(
             override_bs (int, optional): Override the batch size for generation. Defaults to None.
 
         Returns:
-            list[GenerativeResponse]: list of generated responses.
+            list[ModelResponse]: list of generated responses.
         """
         for request in requests:
             request.tokenized_context = self.tok_encode(request.context)
@@ -143,7 +138,7 @@ def greedy_until(
             if result is None:
                 result = ""  # Set to empty string to prevent errors in metric computation
 
-            cur_response = GenerativeResponse(
+            cur_response = ModelResponse(
                 result=result,
                 logits=None,
                 generated_tokens=[],
@@ -169,24 +164,15 @@ def max_length(self) -> int:
         """Return the maximum sequence length of the model."""
         return 4096
 
-    def loglikelihood(self, requests: list[LoglikelihoodRequest]) -> list[LoglikelihoodResponse]:
+    def loglikelihood(self, requests: list[Doc]) -> list[ModelResponse]:
         """Tokenize the context and continuation and compute the log likelihood of those
         tokenized sequences.
         """
         raise NotImplementedError
 
     def loglikelihood_rolling(
         self,
-        requests: list[LoglikelihoodRollingRequest],
-    ) -> list[LoglikelihoodResponse]:
+        requests: list[Doc],
+    ) -> list[ModelResponse]:
         """This function is used to compute the log likelihood of the context for perplexity metrics."""
         raise NotImplementedError
-
-    def loglikelihood_single_token(
-        self,
-        requests: list[LoglikelihoodSingleTokenRequest],
-    ) -> list[LoglikelihoodSingleTokenResponse]:
-        """Tokenize the context and continuation and compute the log likelihood of those
-        tokenized sequences.
-        """
-        raise NotImplementedError

examples/custom_models/local_mt_model.py

Lines changed: 8 additions & 25 deletions
@@ -36,15 +36,10 @@
 from lighteval.data import GenerativeTaskDataset
 from lighteval.models.abstract_model import LightevalModel, TokenSequence
 from lighteval.models.model_output import (
-    GenerativeResponse,
-    LoglikelihoodResponse,
-    LoglikelihoodSingleTokenResponse,
+    ModelResponse,
 )
 from lighteval.tasks.requests import (
-    GreedyUntilRequest,
-    LoglikelihoodRequest,
-    LoglikelihoodRollingRequest,
-    LoglikelihoodSingleTokenRequest,
+    Doc,
 )


@@ -119,9 +114,9 @@ def _convert_to_iso3(self, lang_code: str) -> str:
 
     def greedy_until(
         self,
-        requests: list[GreedyUntilRequest],
+        requests: list[Doc],
         override_bs: Optional[int] = None,
-    ) -> list[GenerativeResponse]:
+    ) -> list[ModelResponse]:
         """
         Generates responses using a greedy decoding strategy until certain ending conditions are met.
         Results are cached to disk to avoid repeated translations.
@@ -131,7 +126,7 @@ def greedy_until(
             override_bs (int, optional): Override the batch size for generation. Defaults to None.
 
         Returns:
-            list[GenerativeResponse]: list of generated responses.
+            list[ModelResponse]: list of generated responses.
         """
 
         def get_langs(task_name: str) -> tuple[str, str]:
@@ -204,7 +199,7 @@ def get_langs(task_name: str) -> tuple[str, str]:
            # Create responses for the batch
            for input_tokens, output_tokens, translation in zip(input_ids, output_ids, translations):
                results.append(
-                    GenerativeResponse(
+                    ModelResponse(
                        input_tokens=input_tokens,
                        generated_tokens=output_tokens,
                        result=translation,
@@ -256,24 +251,12 @@ def max_length(self) -> int:
         """Return the maximum sequence length of the model."""
         return 4096
 
-    def loglikelihood(
-        self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
-    ) -> list[LoglikelihoodResponse]:
+    def loglikelihood(self, requests: list[Doc], override_bs: Optional[int] = None) -> list[ModelResponse]:
         """Tokenize the context and continuation and compute the log likelihood of those
         tokenized sequences.
         """
         raise NotImplementedError
 
-    def loglikelihood_rolling(
-        self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
-    ) -> list[LoglikelihoodResponse]:
+    def loglikelihood_rolling(self, requests: list[Doc], override_bs: Optional[int] = None) -> list[ModelResponse]:
         """This function is used to compute the log likelihood of the context for perplexity metrics."""
         raise NotImplementedError
-
-    def loglikelihood_single_token(
-        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
-    ) -> list[LoglikelihoodSingleTokenResponse]:
-        """Tokenize the context and continuation and compute the log likelihood of those
-        tokenized sequences.
-        """
-        raise NotImplementedError

src/lighteval/main_accelerate.py

Lines changed: 2 additions & 0 deletions
@@ -154,8 +154,10 @@ def accelerate(  # noqa C901
     config: dict = ModelConfig._parse_args(model_args)
 
     if config.get("delta_weights", False):
+        config.pop("delta_weights")
         model_config = DeltaModelConfig(**config)
     elif config.get("adapter_weights", False):
+        config.pop("adapter_weights")
         model_config = AdapterModelConfig(**config)
     else:
         if vision_model:
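
The added `config.pop(...)` calls go hand in hand with `ModelConfig` being declared with `extra="forbid"` (see the `abstract_model.py` change below): a pydantic model with that setting rejects any keyword it does not declare, so selector-style keys have to be removed before the config dict is splatted into the constructor. A minimal illustration of that behaviour, using made-up config names rather than lighteval code:

```python
# Minimal illustration (not lighteval code) of pydantic's extra="forbid":
# any keyword that the model does not declare raises a ValidationError.
from pydantic import BaseModel, ValidationError


class ExampleConfig(BaseModel, extra="forbid"):
    model_name: str


try:
    ExampleConfig(model_name="my-model", delta_weights=True)
except ValidationError as err:
    print(err)  # complains about the undeclared 'delta_weights' key

config = {"model_name": "my-model", "delta_weights": True}
config.pop("delta_weights")     # same pattern as in the diff above
print(ExampleConfig(**config))  # now validates fine
```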

src/lighteval/models/abstract_model.py

Lines changed: 2 additions & 1 deletion
@@ -81,6 +81,7 @@ class ModelConfig(BaseModel, extra="forbid"):
 
     generation_parameters: GenerationParameters = GenerationParameters()
     system_prompt: str | None = None
+    cache_dir: str = "./cache/huggingface/lighteval"
 
     @classmethod
     def from_path(cls, path: str):
@@ -191,7 +192,7 @@ def greedy_until(
            docs (list[Doc]): List of documents containing the context for generation.
 
        Returns:
-            list[GenerativeResponse]: list of generated responses.
+            list[ModelResponse]: list of generated responses.
        """
        return NotImplemented
