Skip to content

Commit 6ee4ff8

Browse files
authored
[RFC] Rework the dependencies to be more versatile (#951)
This PR proposes a change in the way optional dependencies are managed so that it's easier to shield tasks that have specific dependencies. See below for an explanation of the different changes. The first change renames the `is_xxx_available() -> bool` methods to a more versatile `is_package_available(package_name: str) -> bool`. This unbloats the `imports.py` module while adding two additional methods that work on top of it: - `is_multilingual_package_available(language_code: Optional[str] = None) -> bool` which checks the presence of `spacy`, and optionally takes in a language code to check whether the language-related dependencies are installed. - the `@requires(package_name)` decorator which can be added on top of classes/methods to raise an error as soon as any method of the class/the method is called, if the package isn't installed. FYI: this is a draft PR that isn't tested at this point, and that I need to test in a real-world scenario. I'd like to validate the direction before adding relevant tests.
1 parent e61b868 commit 6ee4ff8

33 files changed

+361
-312
lines changed

docs/source/using-the-python-api.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ import lighteval
1212
from lighteval.logging.evaluation_tracker import EvaluationTracker
1313
from lighteval.models.vllm.vllm_model import VLLMModelConfig
1414
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
15-
from lighteval.utils.imports import is_accelerate_available
15+
from lighteval.utils.imports import is_package_available
1616

17-
if is_accelerate_available():
17+
if is_package_available("accelerate"):
1818
from datetime import timedelta
1919
from accelerate import Accelerator, InitProcessGroupKwargs
2020
accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ dependencies = [
8484
"fsspec>=2023.12.2",
8585
"httpx>=0.27.2",
8686
"latex2sympy2_extended==1.0.6",
87+
"langcodes",
8788
]
8889

8990
[project.optional-dependencies]
@@ -98,6 +99,7 @@ nanotron = [
9899
]
99100
tensorboardX = ["tensorboardX"]
100101
vllm = ["vllm>=0.10.0,<0.10.2", "ray", "more_itertools"]
102+
sglang = ["sglang"]
101103
quality = ["ruff>=v0.11.0","pre-commit"]
102104
tests = ["pytest>=7.4.0","deepdiff","pip>=25.2"]
103105
dev = ["lighteval[accelerate,quality,tests,multilingual,math,extended_tasks,vllm]"]

src/lighteval/logging/evaluation_tracker.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@
4343
TaskConfigLogger,
4444
VersionsLogger,
4545
)
46-
from lighteval.utils.imports import NO_TENSORBOARDX_WARN_MSG, is_nanotron_available, is_tensorboardX_available
46+
from lighteval.utils.imports import is_package_available, not_installed_error_message
4747
from lighteval.utils.utils import obj_to_markdown
4848

4949

5050
logger = logging.getLogger(__name__)
5151

52-
if is_nanotron_available():
52+
if is_package_available("nanotron"):
5353
from nanotron.config import GeneralArgs # type: ignore
5454

5555
try:
@@ -659,11 +659,11 @@ def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901
659659
def push_to_tensorboard( # noqa: C901
660660
self, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
661661
):
662-
if not is_tensorboardX_available:
663-
logger.warning(NO_TENSORBOARDX_WARN_MSG)
662+
if not is_package_available("tensorboardX"):
663+
logger.warning(not_installed_error_message("tensorboardX"))
664664
return
665665

666-
if not is_nanotron_available():
666+
if not is_package_available("nanotron"):
667667
logger.warning("You cannot push results to tensorboard without having nanotron installed. Skipping")
668668
return
669669

src/lighteval/logging/info_loggers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@
3434
from lighteval.models.model_output import ModelResponse
3535
from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig
3636
from lighteval.tasks.requests import Doc
37-
from lighteval.utils.imports import is_nanotron_available
37+
from lighteval.utils.imports import is_package_available
3838

3939

4040
logger = logging.getLogger(__name__)
4141

4242

43-
if is_nanotron_available():
43+
if is_package_available("nanotron"):
4444
pass
4545

4646

src/lighteval/main_nanotron.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,13 @@
3232
reasoning_tags,
3333
remove_reasoning_tags,
3434
)
35+
from lighteval.utils.imports import requires
3536

3637

3738
SEED = 1234
3839

3940

41+
@requires("nanotron")
4042
def nanotron(
4143
checkpoint_config_path: Annotated[
4244
str, Option(help="Path to the nanotron checkpoint YAML or python config file, potentially on s3.")
@@ -45,12 +47,9 @@ def nanotron(
4547
remove_reasoning_tags: remove_reasoning_tags.type = remove_reasoning_tags.default,
4648
reasoning_tags: reasoning_tags.type = reasoning_tags.default,
4749
):
48-
"""Evaluate models using nanotron as backend."""
49-
from lighteval.utils.imports import NO_NANOTRON_ERROR_MSG, is_nanotron_available
50-
51-
if not is_nanotron_available():
52-
raise ImportError(NO_NANOTRON_ERROR_MSG)
53-
50+
"""
51+
Evaluate models using nanotron as backend.
52+
"""
5453
from nanotron.config import GeneralArgs, ModelArgs, TokenizerArgs, get_config_from_dict, get_config_from_file
5554

5655
from lighteval.logging.evaluation_tracker import EvaluationTracker

src/lighteval/metrics/imports/data_stats_metric.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from typing import Literal
3131

3232
from lighteval.metrics.imports.data_stats_utils import Fragments
33-
from lighteval.utils.imports import NO_SPACY_ERROR_MSG, is_spacy_available
33+
from lighteval.utils.imports import Extra, requires
3434

3535

3636
logger = logging.getLogger(__name__)
@@ -55,6 +55,7 @@ def find_ngrams(input_list, n):
5555
return zip(*[input_list[i:] for i in range(n)])
5656

5757

58+
@requires(Extra.MULTILINGUAL)
5859
class DataStatsMetric(Metric):
5960
def __init__(
6061
self,
@@ -86,8 +87,6 @@ def __init__(
8687
determines the spaCy model used for tokenization. Currently supports English,
8788
German, French, and Italian.
8889
"""
89-
if not is_spacy_available():
90-
raise ImportError(NO_SPACY_ERROR_MSG)
9190
import spacy
9291

9392
self.n_gram = n_gram

src/lighteval/metrics/normalizations.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from typing import Callable
2929

3030
from lighteval.metrics.utils.linguistic_tokenizers import get_word_tokenizer
31+
from lighteval.utils.imports import Extra, requires
3132
from lighteval.utils.language import Language
3233

3334

@@ -444,15 +445,16 @@ def remove_punc(text: str) -> str:
444445
return "".join(ch for ch in text if ch not in PUNCT)
445446

446447

448+
@requires(Extra.MULTILINGUAL)
447449
def get_multilingual_normalizer(lang: Language, lower: bool = True) -> Callable[[str], str]:
448450
"""Get a normalizer function for the specified language.
449451
450452
Returns:
451453
Callable[[str], str]: A function that normalizes text for the specified language
452454
"""
453-
tokenizer = get_word_tokenizer(lang)
454455

455456
def _inner_normalizer(text: str) -> str:
457+
tokenizer = get_word_tokenizer(lang)
456458
text = remove_articles(text, lang)
457459
text = remove_punc(text)
458460
if lower:

src/lighteval/metrics/utils/extractive_match_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@
3434
from lighteval.tasks.requests import Doc
3535
from lighteval.tasks.templates.utils.formulation import ChoicePrefix, get_prefix
3636
from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS
37-
from lighteval.utils.imports import requires_latex2sympy2_extended
37+
from lighteval.utils.imports import requires
3838
from lighteval.utils.language import Language
3939
from lighteval.utils.timeout import timeout
4040

4141

42-
@requires_latex2sympy2_extended
42+
@requires("latex2sympy2_extended")
4343
def latex_normalization_config_default_factory():
4444
from latex2sympy2_extended.latex2sympy2 import NormalizationConfig
4545

@@ -373,7 +373,7 @@ def get_target_type_order(target_type: ExtractionTarget) -> int:
373373

374374
# Small cache, to cache repeated invalid parsing calls
375375
@lru_cache(maxsize=20)
376-
@requires_latex2sympy2_extended
376+
@requires("latex2sympy2_extended")
377377
def parse_latex_with_timeout(latex: str, timeout_seconds: int):
378378
from latex2sympy2_extended.latex2sympy2 import latex2sympy
379379

@@ -428,7 +428,7 @@ def convert_to_pct(number: Number):
428428
return sympy.Mul(number, sympy.Rational(1, 100), evaluate=False)
429429

430430

431-
@requires_latex2sympy2_extended
431+
@requires("latex2sympy2_extended")
432432
@lru_cache(maxsize=20)
433433
def extract_latex(
434434
match: re.Match, latex_config: LatexExtractionConfig, timeout_seconds: int

src/lighteval/metrics/utils/linguistic_tokenizers.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,8 @@
1818
from typing import Callable, Iterator
1919

2020
from lighteval.utils.imports import (
21-
NO_SPACY_TOKENIZER_ERROR_MSG,
22-
NO_STANZA_TOKENIZER_ERROR_MSG,
23-
can_load_spacy_tokenizer,
24-
can_load_stanza_tokenizer,
21+
Extra,
22+
requires,
2523
)
2624
from lighteval.utils.language import Language
2725

@@ -99,11 +97,10 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]:
9997
return list(self.tokenizer.span_tokenize(text))
10098

10199

100+
@requires(Extra.MULTILINGUAL)
102101
class SpaCyTokenizer(WordTokenizer):
103102
def __init__(self, spacy_language: str, config=None):
104103
super().__init__()
105-
if not can_load_spacy_tokenizer(spacy_language):
106-
raise ImportError(NO_SPACY_TOKENIZER_ERROR_MSG)
107104
self.spacy_language = spacy_language
108105
self.config = config
109106
self._tokenizer = None
@@ -137,11 +134,10 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]:
137134
]
138135

139136

137+
@requires("stanza")
140138
class StanzaTokenizer(WordTokenizer):
141139
def __init__(self, stanza_language: str, **stanza_kwargs):
142140
super().__init__()
143-
if not can_load_stanza_tokenizer():
144-
raise ImportError(NO_STANZA_TOKENIZER_ERROR_MSG)
145141
self.stanza_language = stanza_language
146142
self.stanza_kwargs = stanza_kwargs
147143
self._tokenizer = None

src/lighteval/metrics/utils/llm_as_judge.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from tqdm import tqdm
3535
from tqdm.asyncio import tqdm_asyncio
3636

37-
from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
37+
from lighteval.utils.imports import raise_if_package_not_available
3838
from lighteval.utils.utils import as_list
3939

4040

@@ -151,8 +151,7 @@ def __lazy_load_client(self): # noqa: C901
151151
# Both "openai" and "tgi" backends use the OpenAI-compatible API
152152
# They are handled separately to allow for backend-specific validation and setup
153153
case "openai" | "tgi":
154-
if not is_openai_available():
155-
raise RuntimeError("OpenAI backend is not available.")
154+
raise_if_package_not_available("openai")
156155
if self.client is None:
157156
from openai import OpenAI
158157

@@ -162,13 +161,11 @@ def __lazy_load_client(self): # noqa: C901
162161
return self.__call_api_parallel
163162

164163
case "litellm":
165-
if not is_litellm_available():
166-
raise RuntimeError("litellm is not available.")
164+
raise_if_package_not_available("litellm")
167165
return self.__call_litellm
168166

169167
case "vllm":
170-
if not is_vllm_available():
171-
raise RuntimeError("vllm is not available.")
168+
raise_if_package_not_available("vllm")
172169
if self.pipe is None:
173170
from vllm import LLM, SamplingParams
174171
from vllm.transformers_utils.tokenizer import get_tokenizer

0 commit comments

Comments (0)