
Commit 5c79e2b
Final fixes
1 parent e2fe723 commit 5c79e2b

7 files changed: +86 -91 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -85,7 +85,6 @@ dependencies = [
     "httpx>=0.27.2",
     "latex2sympy2_extended==1.0.6",
     "langcodes",
-    "sglang"
 ]

 [project.optional-dependencies]
@@ -100,6 +99,7 @@ nanotron = [
 ]
 tensorboardX = ["tensorboardX"]
 vllm = ["vllm>=0.10.0,<0.10.2", "ray", "more_itertools"]
+sglang = ["sglang"]
 quality = ["ruff>=v0.11.0","pre-commit"]
 tests = ["pytest>=7.4.0","deepdiff","pip>=25.2"]
 dev = ["lighteval[accelerate,quality,tests,multilingual,math,extended_tasks,vllm]"]
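
With "sglang" removed from the core dependency list and re-added as an optional extra, a default install no longer pulls it in; users who want the SGLang backend install it explicitly with pip install "lighteval[sglang]". A minimal sketch of how downstream code can probe for it, assuming a standard lighteval install without the extra:

    # Sketch: the string name is resolved against pyproject.toml by the helper
    # this commit also touches.
    from lighteval.utils.imports import is_package_available

    if is_package_available("sglang"):
        print("sglang backend usable")
    else:
        print("install with: pip install 'lighteval[sglang]'")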

src/lighteval/models/endpoints/tgi_model.py

Lines changed: 3 additions & 2 deletions

@@ -35,7 +35,7 @@
 from lighteval.utils.imports import Extra, is_package_available, requires


-if is_package_available("tgi"):
+if is_package_available(Extra.TGI):
     from text_generation import AsyncClient
 else:
     from unittest.mock import Mock
@@ -99,7 +99,6 @@ class TGIModelConfig(ModelConfig):

 # inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite
 # the client functions, since they use a different client.
-@requires(Extra.TGI)
 class ModelClient(InferenceEndpointModel):
     _DEFAULT_MAX_LENGTH: int = 4096

@@ -134,6 +133,7 @@ def __init__(self, config: TGIModelConfig) -> None:
         # Initialize cache for tokenization and predictions
         self._cache = SampleCache(config)

+    @requires(Extra.TGI)
     def _async_process_request(
         self,
         context: str,
@@ -173,6 +173,7 @@ def _async_process_request(

         return generated_text

+    @requires(Extra.TGI)
     def _process_request(self, *args, **kwargs) -> TextGenerationOutput:
         return asyncio.run(self._async_process_request(*args, **kwargs))

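Moving @requires(Extra.TGI) from the ModelClient class down to the two request methods narrows the failure point: importing the module or constructing the client no longer demands the tgi extra, and the ImportError only fires when a request is actually made. A self-contained sketch of that call-time gating (DemoClient is a hypothetical stand-in, not lighteval code):

    # Method-level gating with @requires, mirroring this diff's pattern.
    from lighteval.utils.imports import Extra, requires

    class DemoClient:
        def __init__(self):        # construction needs no optional backend
            self.ready = True

        @requires(Extra.TGI)
        def process(self):         # the availability check runs at call time
            return "ok"

    client = DemoClient()          # succeeds whether or not tgi is installed
    # client.process()             # would raise ImportError without `pip install lighteval[tgi]`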

src/lighteval/models/sglang/sglang_model.py

Lines changed: 2 additions & 2 deletions

@@ -138,7 +138,6 @@ class SGLangModelConfig(ModelConfig):
     override_chat_template: bool = None


-@requires("sglang")
 class SGLangModel(LightevalModel):
     def __init__(
         self,
@@ -187,7 +186,7 @@ def add_special_tokens(self):
     def max_length(self) -> int:
         return self._max_length

-    def _create_auto_model(self, config: SGLangModelConfig) -> Optional[Engine]:
+    def _create_auto_model(self, config: SGLangModelConfig) -> Optional["Engine"]:
         self.model_args = {
             "model_path": config.model_name,
             "trust_remote_code": config.trust_remote_code,
@@ -314,6 +313,7 @@ def _greedy_until(
             results.append(cur_response)
         return dataset.get_original_order(results)

+    @requires("sglang")
     def _generate(
         self,
         inputs: list[list[int]],
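
The annotation change is the subtle part here: without from __future__ import annotations, a return annotation is evaluated when the def statement executes, so Optional[Engine] raises NameError at import time whenever sglang (and therefore Engine) is absent. Quoting the name defers resolution to type checkers. A minimal, runnable illustration:

    # The quoted forward reference keeps the module importable even though
    # Engine is never defined in this snippet.
    from typing import Optional

    def _create_auto_model(config) -> Optional["Engine"]:  # no NameError at def time
        return None

    print(_create_auto_model(None))  # -> None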

src/lighteval/utils/imports.py

Lines changed: 34 additions & 17 deletions

@@ -41,6 +41,11 @@ def is_package_available(package: str | Requirement | Extra):

     # If the package is a string, it will get the potential required version from the pyproject.toml
     if isinstance(package, str):
+        if package not in deps:
+            raise RuntimeError(
+                f"Package {package} was tested against, but isn't specified in the pyproject.toml file. Please specify"
+                f"it as a potential dependency or an extra for it to be checked."
+            )
         package = deps[package]

     # If the specified package is an "Extra", we will iterate through each required dependency of that extra
@@ -75,6 +80,10 @@ def required_dependencies() -> Tuple[Dict[str, Requirement], Dict[str, List[Requirement]]]:
         extra = None
         if ";" in dep:
             dep, marker = dep.split(";", 1)
+
+            # The `metadata` function prints requirements as follows
+            # 'vllm<0.10.2,>=0.10.0; extra == "vllm"'
+            # The regex searches for "extra == <MARKER>" in order to parse the marker.
             match = re.search(r'extra\s*==\s*"(.*?)"', marker)
             extra = match.group(1) if match else None
         requirement = Requirement(dep.strip())
@@ -146,30 +155,38 @@ def __getattribute__(cls, key):
         for backend in cls._backends:
             raise_if_package_not_available(backend)

+        return super().__getattribute__(key)

-def requires(*backends):
-    """
-    Decorator to raise an ImportError if the decorated object (function or class) requires a dependency
-    which is not installed.
-    """

+def parse_specified_backends(specified_backends):
     requirements, _ = required_dependencies()
-
     applied_backends = []
-    for backend in backends:
+
+    for backend in specified_backends:
         if isinstance(backend, Extra):
-            applied_backends.append(backend)
+            applied_backends.append(backend if isinstance(backend, Extra) else requirements[backend])
+        elif backend not in requirements:
+            raise RuntimeError(
+                "A dependency was specified with @requires, but it is not defined in the possible dependencies "
+                f"defined in the pyproject.toml: `{backend}`."
+                f""
+                f"If editing the pyproject.toml to add a new dependency, remember to reinstall lighteval for the"
+                f"update to take effect."
+            )
         else:
-            if backend not in requirements:
-                raise RuntimeError(
-                    "A dependency was specified with @requires, but it is not defined in the possible dependencies "
-                    f"defined in the pyproject.toml: `{backend}`."
-                    f""
-                    f"If editing the pyproject.toml to add a new dependency, remember to reinstall lighteval for the"
-                    f"update to take effect."
-                )
             applied_backends.append(requirements[backend])

+    return applied_backends
+
+
+def requires(*specified_backends):
+    """
+    Decorator to raise an ImportError if the decorated object (function or class) requires a dependency
+    which is not installed.
+    """
+
+    applied_backends = parse_specified_backends(specified_backends)
+
     def inner_fn(_object):
         _object._backends = applied_backends

@@ -185,7 +202,7 @@ def __init__(self, *args, **kwargs):
         Placeholder.__name__ = _object.__name__
         Placeholder.__module__ = _object.__module__

-        return Placeholder
+        return _object if all(is_package_available(backend) for backend in applied_backends) else Placeholder
     else:

         @functools.wraps(_object)
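
Two changes here matter behaviorally: DummyObject.__getattribute__ now returns the attribute after the availability check instead of falling through, and requires() now returns the decorated class unchanged when every backend is available, substituting the Placeholder only when something is missing. A usage sketch of the refactored decorator, assuming lighteval is installed ("vllm" is a declared dependency per this commit's pyproject.toml):

    from lighteval.utils.imports import requires

    @requires("vllm")
    class VLLMBackend:
        pass

    # If vllm is installed, VLLMBackend is the class defined above, untouched.
    # If not, it is a Placeholder carrying the same __name__, and VLLMBackend()
    # raises ImportError pointing at the missing `vllm` requirement.
    print(VLLMBackend.__name__)  # -> "VLLMBackend" either way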

tests/test_dependencies.py

Lines changed: 40 additions & 55 deletions

@@ -1,17 +1,17 @@
 # MIT License
-
+#
 # Copyright (c) 2024 The HuggingFace Team
-
+#
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
 # in the Software without restriction, including without limitation the rights
 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 # copies of the Software, and to permit persons to whom the Software is
 # furnished to do so, subject to the following conditions:
-
+#
 # The above copyright notice and this permission notice shall be included in all
 # copies or substantial portions of the Software.
-
+#
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -20,75 +20,60 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

-# tests/utils/pretend_missing.py
-import functools
-import importlib
-
 import pytest
-from packaging.requirements import Requirement
-
-import lighteval.utils.imports as imports
-
-
-def pretend_missing(*names):
-    """
-    Decorator: pretend that certain packages are missing
-    by patching mypkg.utils.is_package_available.
-    """
-
-    def decorator(test_func):
-        @functools.wraps(test_func)
-        def wrapper(*args, **kwargs):
-            from unittest.mock import patch
-
-            def fake(requirement):
-                name = requirement.name if isinstance(requirement, Requirement) else requirement
-                return False if name in names else (importlib.util.find_spec(name) is not None)

-            with patch.object(imports, "is_package_available", side_effect=fake):
-                # If your module caches results at import time, reload here
-                import lighteval
+from lighteval.utils.imports import Extra, is_package_available, requires

-                importlib.reload(lighteval)

-            return test_func(*args, **kwargs)
+def test_requires():
+    @requires("sglang")
+    class RandomModel:
+        pass

-        return wrapper
+    assert RandomModel.__name__ == "RandomModel"
+    assert RandomModel.__class__.__name__ == "DummyObject"

-    return decorator
+    with pytest.raises(
+        ImportError,
+        match="Through the use of RandomModel, you requested the use of `sglang` for this evaluation, but it is not available in your current environment. Please install it using pip.",
+    ):
+        RandomModel()


-@pretend_missing("langdetect")
-def test_langdetect_required_for_ifeval():
-    from lighteval.main_accelerate import accelerate
+def test_requires_with_extra():
+    @requires(Extra.TGI)
+    class RandomModel:
+        pass

     with pytest.raises(
         ImportError,
-        match="Through the use of ifeval_prompt, you requested the use of `langdetect` for this evaluation, but it is not available in your current environment. Please install it using pip.",
+        match=r"Through the use of RandomModel, you are trying to run an evaluation requiring tgi capabilities. Please install the required extra: `pip install lighteval\[tgi\]`",
     ):
-        accelerate(model_args="model_name=gpt2,batch_size=1", tasks="extended|ifeval|0", max_samples=0)
+        RandomModel()


-@pretend_missing("spacy", "stanza")
-def test_multilingual_required_for_xnli():
-    """
-    This checks that the Extra.MULTILINGUAL correctly raises if there are missing dependencies.
-    """
-    from lighteval.main_accelerate import accelerate
-
+def test_requires_with_wrong_dependency():
     with pytest.raises(
-        ImportError,
-        match="Through the use of get_multilingual_normalizer, you are trying to run an evaluation requiring multilingual capabilities.",
+        RuntimeError,
+        match="A dependency was specified with @requires, but it is not defined in the possible dependencies defined in the pyproject.toml: `random_dependency`",
     ):
-        accelerate(model_args="model_name=gpt2,batch_size=1", tasks="multilingual|xnli_zho_mcf|0", max_samples=0)

+        @requires("random_dependency")
+        class RandomModel:
+            pass


-@pretend_missing("vllm")
-def test_vllm_required_for_vllm_usage():
-    from lighteval.main_vllm import vllm

+def test_is_package_available():
+    assert is_package_available("torch")
+
+
+def test_is_package_unavailable():
+    assert not is_package_available("tensorboardX")
+
+
+def test_is_package_is_not_specified_in_pyproject_toml():
     with pytest.raises(
-        ImportError,
-        match="Through the use of VLLMModel, you requested the use of `vllm<0.10.2,>=0.10.0` for this evaluation, but it is not available in your current environment. Please install it using pip.",
+        RuntimeError,
+        match="Package tensorflow was tested against, but isn't specified in the pyproject.toml file. Please specifyit as a potential dependency or an extra for it to be checked.",
     ):
-        vllm(model_args="model_name=gpt2", tasks="lighteval|aime24|0", max_samples=0)
+        is_package_available("tensorflow")
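
The rewritten tests pin exact error strings, including the odd "Please specifyit" wording, which faithfully reproduces how the adjacent f-strings in imports.py concatenate without a separator. A tiny runnable illustration of that mechanic, using the strings from this diff:

    # Adjacent string literals concatenate with no implicit space, which is why
    # the expected message reads "specifyit".
    msg = (
        "Please specify"
        "it as a potential dependency or an extra for it to be checked."
    )
    assert msg == "Please specifyit as a potential dependency or an extra for it to be checked."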

tests/unit/models/test_transformers_model.py

Lines changed: 2 additions & 2 deletions

@@ -398,9 +398,9 @@ def mock_gather(tensor):
 class TestTransformersModelUseChatTemplate(unittest.TestCase):
     @patch("lighteval.models.transformers.transformers_model.Accelerator")
     @patch("lighteval.models.transformers.transformers_model.TransformersModel._create_auto_model")
-    @patch("lighteval.utils.imports.is_accelerate_available")
+    @patch("lighteval.utils.imports.is_package_available")
     def test_transformers_model_use_chat_template_with_different_model_names(
-        self, mock_accelerator, mock_create_model, is_accelerate_available
+        self, mock_accelerator, mock_create_model, is_package_available
     ):
         """Test that TransformersModel correctly determines whether to use_chat_template or not automatically from the tokenizer config."""
         test_cases = [
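
The renamed mock parameter is a good moment to recall how stacked @patch decorators bind: they apply bottom-up, so the decorator nearest the function supplies the first mock argument. A self-contained sketch of standard unittest.mock behavior, not lighteval-specific:

    # Stacked @patch decorators inject mocks bottom-up: the innermost decorator
    # (nearest the def) provides the first mock parameter.
    from unittest.mock import patch

    @patch("os.getcwd")   # outermost -> bound last
    @patch("os.listdir")  # innermost -> bound first
    def demo(mock_listdir, mock_getcwd):
        assert mock_listdir is not mock_getcwd
        print("bottom-up binding confirmed")

    demo()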

tests/unit/utils/test_caching.py

Lines changed: 4 additions & 12 deletions

@@ -32,6 +32,7 @@
 from lighteval.models.model_output import ModelResponse
 from lighteval.tasks.requests import Doc, SamplingMethod
 from lighteval.utils.cache_management import SampleCache
+from lighteval.utils.imports import Extra, is_package_available


 class TestCaching(unittest.TestCase):
@@ -177,12 +178,9 @@ def _test_cache(self, model: LightevalModel, test_cases):

     @patch("lighteval.models.transformers.transformers_model.TransformersModel._loglikelihood_tokens")
     @patch("lighteval.models.transformers.transformers_model.TransformersModel._padded_greedy_until")
-    @patch("lighteval.utils.imports.is_accelerate_available")
     @patch("lighteval.models.transformers.transformers_model.Accelerator")
     @patch("lighteval.models.transformers.transformers_model.TransformersModel._create_auto_model")
-    def test_cache_transformers(
-        self, mock_create_model, mock_accelerator, mock_is_accelerate_available, mock_greedy_until, mock_loglikelihood
-    ):
+    def test_cache_transformers(self, mock_create_model, mock_accelerator, mock_greedy_until, mock_loglikelihood):
         from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig

         # Skip the model creation phase
@@ -192,7 +190,6 @@ def test_cache_transformers(
         mock_accelerator_instance = Mock()
         mock_accelerator_instance.device = torch.device("cpu")
         mock_accelerator.return_value = mock_accelerator_instance
-        mock_is_accelerate_available = False  # noqa F841

         mock_greedy_until.return_value = self.model_responses
         mock_loglikelihood.return_value = self.model_responses
@@ -237,9 +234,8 @@ def test_cache_vllm(self, mock_create_model, mock_greedy_until, mock_loglikelihood):
     @patch("lighteval.models.endpoints.tgi_model.ModelClient._loglikelihood")
     def test_cache_tgi(self, mock_loglikelihood, mock_greedy_until, mock_requests_get):
         from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig
-        from lighteval.utils.imports import is_package_available

-        if not is_package_available("tgi"):
+        if not is_package_available(Extra.TGI):
             pytest.skip("Skipping because missing the imports")

         # Mock TGI requests
@@ -320,12 +316,9 @@ def test_cache_sglang(
         )

     @patch("lighteval.models.transformers.vlm_transformers_model.VLMTransformersModel._greedy_until")
-    @patch("lighteval.utils.imports.is_accelerate_available")
     @patch("lighteval.models.transformers.vlm_transformers_model.Accelerator")
     @patch("lighteval.models.transformers.vlm_transformers_model.VLMTransformersModel._create_auto_model")
-    def test_cache_vlm_transformers(
-        self, mock_create_model, mock_accelerator, is_accelerate_available, mock_greedy_until
-    ):
+    def test_cache_vlm_transformers(self, mock_create_model, mock_accelerator, mock_greedy_until):
         from lighteval.models.transformers.vlm_transformers_model import (
             VLMTransformersModel,
             VLMTransformersModelConfig,
@@ -335,7 +328,6 @@ def test_cache_vlm_transformers(
         mock_accelerator_instance = Mock()
         mock_accelerator_instance.device = torch.device("cpu")
         mock_accelerator.return_value = mock_accelerator_instance
-        is_accelerate_available = False  # noqa F841

         # Skip the model creation phase
         mock_create_model = Mock()  # noqa F841
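
The deleted mock_is_accelerate_available = False lines were dead assignments: rebinding a parameter name inside the test never reconfigures the patched object, which is why they could be dropped together with the corresponding @patch decorators. A tiny illustration of the pitfall:

    # Rebinding a name that holds a mock does not change the mock itself;
    # the original object is untouched and the assignment is dead code.
    from unittest.mock import Mock

    m = Mock(return_value=True)
    handle = m
    handle = False  # noqa: F841 -- rebinds `handle` only; `m` still returns True
    assert m() is True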
