Skip to content

Commit 2c6f61a

Browse files
committed
Update
1 parent ac2d18e commit 2c6f61a

File tree

10 files changed

+106
-64
lines changed

10 files changed

+106
-64
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ dependencies = [
8484
"fsspec>=2023.12.2",
8585
"httpx>=0.27.2",
8686
"latex2sympy2_extended==1.0.6",
87-
"langcodes"
87+
"langcodes",
88+
"sglang"
8889
]
8990

9091
[project.optional-dependencies]

src/lighteval/metrics/imports/data_stats_metric.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from typing import Literal
3131

3232
from lighteval.metrics.imports.data_stats_utils import Fragments
33-
from lighteval.utils.imports import Extras, raise_if_package_not_available, requires
33+
from lighteval.utils.imports import Extra, requires
3434

3535

3636
logger = logging.getLogger(__name__)
@@ -55,7 +55,7 @@ def find_ngrams(input_list, n):
5555
return zip(*[input_list[i:] for i in range(n)])
5656

5757

58-
@requires(Extras.MULTILINGUAL)
58+
@requires(Extra.MULTILINGUAL)
5959
class DataStatsMetric(Metric):
6060
def __init__(
6161
self,
@@ -87,7 +87,6 @@ def __init__(
8787
determines the spaCy model used for tokenization. Currently supports English,
8888
German, French, and Italian.
8989
"""
90-
raise_if_package_not_available("spacy")
9190
import spacy
9291

9392
self.n_gram = n_gram

src/lighteval/metrics/normalizations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from typing import Callable
2929

3030
from lighteval.metrics.utils.linguistic_tokenizers import get_word_tokenizer
31-
from lighteval.utils.imports import Extras, requires
31+
from lighteval.utils.imports import Extra, requires
3232
from lighteval.utils.language import Language
3333

3434

@@ -445,7 +445,7 @@ def remove_punc(text: str) -> str:
445445
return "".join(ch for ch in text if ch not in PUNCT)
446446

447447

448-
@requires(Extras.MULTILINGUAL)
448+
@requires(Extra.MULTILINGUAL)
449449
def get_multilingual_normalizer(lang: Language, lower: bool = True) -> Callable[[str], str]:
450450
"""Get a normalizer function for the specified language.
451451

src/lighteval/metrics/utils/linguistic_tokenizers.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
from typing import Callable, Iterator
1919

2020
from lighteval.utils.imports import (
21-
Extras,
22-
raise_if_package_not_available,
21+
Extra,
2322
requires,
2423
)
2524
from lighteval.utils.language import Language
@@ -98,7 +97,7 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]:
9897
return list(self.tokenizer.span_tokenize(text))
9998

10099

101-
@requires(Extras.MULTILINGUAL)
100+
@requires(Extra.MULTILINGUAL)
102101
class SpaCyTokenizer(WordTokenizer):
103102
def __init__(self, spacy_language: str, config=None):
104103
super().__init__()
@@ -139,7 +138,6 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]:
139138
class StanzaTokenizer(WordTokenizer):
140139
def __init__(self, stanza_language: str, **stanza_kwargs):
141140
super().__init__()
142-
raise_if_package_not_available("stanza")
143141
self.stanza_language = stanza_language
144142
self.stanza_kwargs = stanza_kwargs
145143
self._tokenizer = None

src/lighteval/models/endpoints/tgi_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from lighteval.models.endpoints.endpoint_model import InferenceEndpointModel
3333
from lighteval.tasks.prompt_manager import PromptManager
3434
from lighteval.utils.cache_management import SampleCache
35-
from lighteval.utils.imports import is_package_available, requires
35+
from lighteval.utils.imports import Extra, is_package_available, requires
3636

3737

3838
if is_package_available("tgi"):
@@ -99,7 +99,7 @@ class TGIModelConfig(ModelConfig):
9999

100100
# inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite
101101
# the client functions, since they use a different client.
102-
@requires("tgi")
102+
@requires(Extra.TGI)
103103
class ModelClient(InferenceEndpointModel):
104104
_DEFAULT_MAX_LENGTH: int = 4096
105105

src/lighteval/models/model_loader.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig
4444
from lighteval.models.transformers.vlm_transformers_model import VLMTransformersModel, VLMTransformersModelConfig
4545
from lighteval.models.vllm.vllm_model import AsyncVLLMModel, VLLMModel, VLLMModelConfig
46-
from lighteval.utils.imports import raise_if_package_not_available
4746

4847

4948
logger = logging.getLogger(__name__)
@@ -148,7 +147,6 @@ def load_model_with_accelerate_or_default(
148147
elif isinstance(config, DeltaModelConfig):
149148
model = DeltaModel(config=config)
150149
elif isinstance(config, VLLMModelConfig):
151-
raise_if_package_not_available("vllm")
152150
if config.is_async:
153151
model = AsyncVLLMModel(config=config)
154152
else:

src/lighteval/pipeline.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from lighteval.tasks.lighteval_task import LightevalTask
4343
from lighteval.tasks.registry import Registry
4444
from lighteval.tasks.requests import SamplingMethod
45-
from lighteval.utils.imports import is_package_available, raise_if_package_not_available
45+
from lighteval.utils.imports import is_package_available
4646
from lighteval.utils.parallelism import test_all_gather
4747
from lighteval.utils.utils import make_results_table, remove_reasoning_tags
4848

@@ -96,21 +96,6 @@ class PipelineParameters:
9696
bootstrap_iters: int = 1000
9797

9898
def __post_init__(self): # noqa C901
99-
# Import testing
100-
if self.launcher_type == ParallelismManager.ACCELERATE:
101-
raise_if_package_not_available("accelerate")
102-
elif self.launcher_type == ParallelismManager.VLLM:
103-
raise_if_package_not_available("vllm")
104-
elif self.launcher_type == ParallelismManager.SGLANG:
105-
raise_if_package_not_available("sglang")
106-
elif self.launcher_type == ParallelismManager.TGI:
107-
raise_if_package_not_available("tgi")
108-
elif self.launcher_type == ParallelismManager.NANOTRON:
109-
raise_if_package_not_available("nanotron")
110-
elif self.launcher_type == ParallelismManager.OPENAI:
111-
raise_if_package_not_available("openai")
112-
113-
# Convert reasoning tags to list if needed
11499
if not isinstance(self.reasoning_tags, list):
115100
try:
116101
self.reasoning_tags = ast.literal_eval(self.reasoning_tags)

src/lighteval/tasks/registry.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,9 @@ def __init__(
115115
self,
116116
tasks: str | Path | None = None,
117117
custom_tasks: str | Path | ModuleType | None = None,
118-
load_community: bool = True,
119-
load_extended: bool = True,
120-
load_multilingual: bool = True,
118+
load_community: bool = False,
119+
load_extended: bool = False,
120+
load_multilingual: bool = False,
121121
):
122122
"""
123123
Initialize the Registry class.
@@ -213,6 +213,13 @@ def _activate_loading_of_optional_suite(self) -> None:
213213
f"Suite {suite_name} unknown. This is not normal, unless you are testing adding new evaluations."
214214
)
215215

216+
if "extended" in suites:
217+
self._load_extended = True
218+
if "multilingual" in suites:
219+
self._load_multilingual = True
220+
if "community" in suites:
221+
self._load_community = True
222+
216223
def _load_full_registry(self) -> dict[str, LightevalTaskConfig]:
217224
"""
218225
Returns:

src/lighteval/utils/imports.py

Lines changed: 78 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,64 @@
1515
import functools
1616
import importlib
1717
import inspect
18+
import re
19+
from collections import defaultdict
1820
from functools import lru_cache
21+
from importlib.metadata import PackageNotFoundError, metadata, version
22+
from typing import Dict, List, Tuple
1923

2024
from packaging.requirements import Requirement
25+
from packaging.version import Version
2126

2227

23-
class Extras(enum.Enum):
28+
# These extras should exist in the pyproject.toml file
29+
class Extra(enum.Enum):
2430
MULTILINGUAL = "multilingual"
2531
EXTENDED = "extended"
32+
TGI = "tgi"
2633

2734

2835
@lru_cache()
29-
def is_package_available(package_name: str | Extras):
30-
if package_name == Extras.MULTILINGUAL:
31-
return all(importlib.util.find_spec(package) is not None for package in ["stanza", "spacy"])
32-
if package_name == Extras.EXTENDED:
33-
return all(importlib.util.find_spec(package) is not None for package in ["spacy"])
36+
def is_package_available(package: str | Requirement | Extra):
37+
deps, deps_by_extra = required_dependencies()
38+
39+
if isinstance(package, str):
40+
package = deps[package]
41+
42+
if isinstance(package, Extra):
43+
dependencies = deps_by_extra[package.value]
44+
return all(is_package_available(_package) for _package in dependencies)
3445
else:
35-
return importlib.util.find_spec(package_name) is not None
46+
try:
47+
installed = Version(version(package.name))
48+
except PackageNotFoundError:
49+
return False
50+
51+
# No version constraint → any installed version is OK
52+
if not package.specifier:
53+
return True
54+
55+
return installed in package.specifier
56+
57+
58+
@lru_cache()
59+
def required_dependencies() -> Tuple[Dict[str, Requirement], Dict[str, List[Requirement]]]:
60+
md = metadata("lighteval")
61+
requires_dist = md.get_all("Requires-Dist") or []
62+
deps_by_extra = defaultdict(list)
63+
deps = {}
64+
65+
for dep in requires_dist:
66+
extra = None
67+
if ";" in dep:
68+
dep, marker = dep.split(";", 1)
69+
match = re.search(r'extra\s*==\s*"(.*?)"', marker)
70+
extra = match.group(1) if match else None
71+
requirement = Requirement(dep.strip())
72+
deps_by_extra[extra].append(requirement)
73+
deps[requirement.name] = requirement
74+
75+
return deps, deps_by_extra
3676

3777

3878
@lru_cache()
@@ -50,33 +90,32 @@ def is_multilingual_package_available(language: str):
5090
return all(cur_import is not None for cur_import in imports)
5191

5292

53-
def raise_if_package_not_available(package_name: str | Extras, *, language: str = None, object_name: str = None):
93+
def raise_if_package_not_available(package: Requirement | Extra, *, language: str = None, object_name: str = None):
5494
prefix = "You" if object_name is None else f"Through the use of {object_name}, you"
5595

56-
if package_name == Extras.MULTILINGUAL and (
57-
(language is not None) or (not is_multilingual_package_available(language))
58-
):
59-
raise ImportError(prefix + not_installed_error_message(package_name)[3:])
96+
if package == Extra.MULTILINGUAL and ((language is not None) or (not is_multilingual_package_available(language))):
97+
raise ImportError(prefix + not_installed_error_message(package)[3:])
6098

61-
if not is_package_available(package_name):
62-
raise ImportError(prefix + not_installed_error_message(package_name)[3:])
99+
if not is_package_available(package):
100+
raise ImportError(prefix + not_installed_error_message(package)[3:])
63101

64102

65-
def not_installed_error_message(package_name: str | Extras) -> str:
66-
if package_name == Extras.MULTILINGUAL.value:
103+
def not_installed_error_message(package: Requirement) -> str:
104+
if package == Extra.MULTILINGUAL.value:
67105
return "You are trying to run an evaluation requiring multilingual capabilities. Please install the required extra: `pip install lighteval[multilingual]`"
68-
elif package_name == Extras.EXTENDED.value:
106+
elif package == Extra.EXTENDED.value:
69107
return "You are trying to run an evaluation requiring additional extensions. Please install the required extra: `pip install lighteval[extended] "
70-
elif package_name == "text_generation":
108+
elif package == "text_generation":
71109
return "You are trying to start a text generation inference endpoint, but TGI is not present in your local environment. Please install it using pip."
72-
elif package_name in ["bitsandbytes", "auto-gptq"]:
73-
return f"You are trying to load a model quantized with `{package_name}`, which is not available in your local environment. Please install it using pip."
74-
elif package_name == "peft":
110+
elif package == "peft":
75111
return "You are trying to use adapter weights models, for which you need `peft`, which is not available in your environment. Please install it using pip."
76-
elif package_name == "openai":
112+
elif package == "openai":
77113
return "You are trying to use an Open AI LLM as a judge, for which you need `openai`, which is not available in your environment. Please install it using pip."
78114

79-
return f"You requested the use of `{package_name}` for this evaluation, but it is not available in your current environment. Please install it using pip."
115+
if isinstance(package, Extra):
116+
return f"You are trying to run an evaluation requiring {package.value} capabilities. Please install the required extra: `pip install lighteval[{package.value}]`"
117+
else:
118+
return f"You requested the use of `{package}` for this evaluation, but it is not available in your current environment. Please install it using pip."
80119

81120

82121
class DummyObject(type):
@@ -101,9 +140,22 @@ def requires(*backends):
101140
which is not installed.
102141
"""
103142

143+
requirements, _ = required_dependencies()
144+
104145
applied_backends = []
105146
for backend in backends:
106-
applied_backends.append(Requirement(backend.value if isinstance(backend, Extras) else backend))
147+
if isinstance(backend, Extra):
148+
applied_backends.append(backend)
149+
else:
150+
if backend not in requirements:
151+
raise RuntimeError(
152+
"A dependency was specified with @requires, but it is not defined in the possible dependencies "
153+
f"defined in the pyproject.toml: `{backend}`."
154+
f""
155+
f"If editing the pyproject.toml to add a new dependency, remember to reinstall lighteval for the"
156+
f"update to take effect."
157+
)
158+
applied_backends.append(requirements[backend])
107159

108160
def inner_fn(_object):
109161
_object._backends = applied_backends
@@ -115,7 +167,7 @@ class Placeholder(metaclass=DummyObject):
115167

116168
def __init__(self, *args, **kwargs):
117169
for backend in self._backends:
118-
raise_if_package_not_available(backend.name, object_name=_object.__class__.__name__)
170+
raise_if_package_not_available(backend, object_name=_object.__name__)
119171

120172
Placeholder.__name__ = _object.__name__
121173
Placeholder.__module__ = _object.__module__
@@ -126,7 +178,7 @@ def __init__(self, *args, **kwargs):
126178
@functools.wraps(_object)
127179
def wrapper(*args, **kwargs):
128180
for backend in _object._backends:
129-
raise_if_package_not_available(backend.name, object_name=_object.__name__)
181+
raise_if_package_not_available(backend, object_name=_object.__name__)
130182
return _object(*args, **kwargs)
131183

132184
return wrapper

tests/test_dependencies.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import importlib
2626

2727
import pytest
28+
from packaging.requirements import Requirement
2829

2930
import lighteval.utils.imports as imports
3031

@@ -40,7 +41,8 @@ def decorator(test_func):
4041
def wrapper(*args, **kwargs):
4142
from unittest.mock import patch
4243

43-
def fake(name):
44+
def fake(requirement):
45+
name = requirement.name if isinstance(requirement, Requirement) else requirement
4446
return False if name in names else (importlib.util.find_spec(name) is not None)
4547

4648
with patch.object(imports, "is_package_available", side_effect=fake):
@@ -73,9 +75,9 @@ def test_multilingual_required_for_xnli():
7375

7476
with pytest.raises(
7577
ImportError,
76-
match="Through the use of get_multilingual_normalizer, you are trying to run an evaluation requiring multilingual capabilities. Please install the required extra: `pip install lighteval[multilingual]`",
78+
match="Through the use of get_multilingual_normalizer, you are trying to run an evaluation requiring multilingual capabilities.",
7779
):
78-
accelerate(model_args="model_name=gpt2,batch_size=1", tasks="lighteval|xnli_zho_mcf|0", max_samples=0)
80+
accelerate(model_args="model_name=gpt2,batch_size=1", tasks="multilingual|xnli_zho_mcf|0", max_samples=0)
7981

8082

8183
@pretend_missing("vllm")
@@ -84,6 +86,6 @@ def test_vllm_required_for_vllm_usage():
8486

8587
with pytest.raises(
8688
ImportError,
87-
match="You requested the use of `vllm` for this evaluation, but it is not available in your current environment. Please install it using pip.'",
89+
match="Through the use of VLLMModel, you requested the use of `vllm<0.10.2,>=0.10.0` for this evaluation, but it is not available in your current environment. Please install it using pip.",
8890
):
89-
vllm(model_args="model_name=gpt2,batch_size=1", tasks="lighteval|xnli_zho_mcf|0", max_samples=0)
91+
vllm(model_args="model_name=gpt2", tasks="lighteval|aime24|0", max_samples=0)

0 commit comments

Comments
 (0)