Skip to content

Commit 2a7b29a

Browse files
committed
Rework the imports to be more versatile
Style
1 parent e61b868 commit 2a7b29a

27 files changed

+144
-278
lines changed

docs/source/using-the-python-api.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ import lighteval
1212
from lighteval.logging.evaluation_tracker import EvaluationTracker
1313
from lighteval.models.vllm.vllm_model import VLLMModelConfig
1414
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
15-
from lighteval.utils.imports import is_accelerate_available
15+
from lighteval.utils.imports import is_package_available
1616

17-
if is_accelerate_available():
17+
if is_package_available("accelerate"):
1818
from datetime import timedelta
1919
from accelerate import Accelerator, InitProcessGroupKwargs
2020
accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])

src/lighteval/logging/evaluation_tracker.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@
4343
TaskConfigLogger,
4444
VersionsLogger,
4545
)
46-
from lighteval.utils.imports import NO_TENSORBOARDX_WARN_MSG, is_nanotron_available, is_tensorboardX_available
46+
from lighteval.utils.imports import is_package_available, not_installed_error_message
4747
from lighteval.utils.utils import obj_to_markdown
4848

4949

5050
logger = logging.getLogger(__name__)
5151

52-
if is_nanotron_available():
52+
if is_package_available("nanotron"):
5353
from nanotron.config import GeneralArgs # type: ignore
5454

5555
try:
@@ -659,11 +659,11 @@ def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901
659659
def push_to_tensorboard( # noqa: C901
660660
self, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
661661
):
662-
if not is_tensorboardX_available:
663-
logger.warning(NO_TENSORBOARDX_WARN_MSG)
662+
if not is_package_available("tensorboardX"):
663+
logger.warning(not_installed_error_message("tensorboardX"))
664664
return
665665

666-
if not is_nanotron_available():
666+
if not is_package_available("nanotron"):
667667
logger.warning("You cannot push results to tensorboard without having nanotron installed. Skipping")
668668
return
669669

src/lighteval/logging/info_loggers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@
3434
from lighteval.models.model_output import ModelResponse
3535
from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig
3636
from lighteval.tasks.requests import Doc
37-
from lighteval.utils.imports import is_nanotron_available
37+
from lighteval.utils.imports import is_package_available
3838

3939

4040
logger = logging.getLogger(__name__)
4141

4242

43-
if is_nanotron_available():
43+
if is_package_available("nanotron"):
4444
pass
4545

4646

src/lighteval/main_nanotron.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,13 @@
3232
reasoning_tags,
3333
remove_reasoning_tags,
3434
)
35+
from lighteval.utils.imports import requires
3536

3637

3738
SEED = 1234
3839

3940

41+
@requires("nanotron")
4042
def nanotron(
4143
checkpoint_config_path: Annotated[
4244
str, Option(help="Path to the nanotron checkpoint YAML or python config file, potentially on s3.")
@@ -45,12 +47,9 @@ def nanotron(
4547
remove_reasoning_tags: remove_reasoning_tags.type = remove_reasoning_tags.default,
4648
reasoning_tags: reasoning_tags.type = reasoning_tags.default,
4749
):
48-
"""Evaluate models using nanotron as backend."""
49-
from lighteval.utils.imports import NO_NANOTRON_ERROR_MSG, is_nanotron_available
50-
51-
if not is_nanotron_available():
52-
raise ImportError(NO_NANOTRON_ERROR_MSG)
53-
50+
"""
51+
Evaluate models using nanotron as backend.
52+
"""
5453
from nanotron.config import GeneralArgs, ModelArgs, TokenizerArgs, get_config_from_dict, get_config_from_file
5554

5655
from lighteval.logging.evaluation_tracker import EvaluationTracker

src/lighteval/metrics/imports/data_stats_metric.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from typing import Literal
3131

3232
from lighteval.metrics.imports.data_stats_utils import Fragments
33-
from lighteval.utils.imports import NO_SPACY_ERROR_MSG, is_spacy_available
33+
from lighteval.utils.imports import raise_if_package_not_available
3434

3535

3636
logger = logging.getLogger(__name__)
@@ -86,8 +86,7 @@ def __init__(
8686
determines the spaCy model used for tokenization. Currently supports English,
8787
German, French, and Italian.
8888
"""
89-
if not is_spacy_available():
90-
raise ImportError(NO_SPACY_ERROR_MSG)
89+
raise_if_package_not_available("spacy")
9190
import spacy
9291

9392
self.n_gram = n_gram

src/lighteval/metrics/utils/extractive_match_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@
3434
from lighteval.tasks.requests import Doc
3535
from lighteval.tasks.templates.utils.formulation import ChoicePrefix, get_prefix
3636
from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS
37-
from lighteval.utils.imports import requires_latex2sympy2_extended
37+
from lighteval.utils.imports import requires
3838
from lighteval.utils.language import Language
3939
from lighteval.utils.timeout import timeout
4040

4141

42-
@requires_latex2sympy2_extended
42+
@requires("latex2sympy2_extended")
4343
def latex_normalization_config_default_factory():
4444
from latex2sympy2_extended.latex2sympy2 import NormalizationConfig
4545

@@ -373,7 +373,7 @@ def get_target_type_order(target_type: ExtractionTarget) -> int:
373373

374374
# Small cache, to catche repeated calls invalid parsing
375375
@lru_cache(maxsize=20)
376-
@requires_latex2sympy2_extended
376+
@requires("latex2sympy2_extended")
377377
def parse_latex_with_timeout(latex: str, timeout_seconds: int):
378378
from latex2sympy2_extended.latex2sympy2 import latex2sympy
379379

@@ -428,7 +428,7 @@ def convert_to_pct(number: Number):
428428
return sympy.Mul(number, sympy.Rational(1, 100), evaluate=False)
429429

430430

431-
@requires_latex2sympy2_extended
431+
@requires("latex2sympy2_extended")
432432
@lru_cache(maxsize=20)
433433
def extract_latex(
434434
match: re.Match, latex_config: LatexExtractionConfig, timeout_seconds: int

src/lighteval/metrics/utils/linguistic_tokenizers.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,8 @@
1818
from typing import Callable, Iterator
1919

2020
from lighteval.utils.imports import (
21-
NO_SPACY_TOKENIZER_ERROR_MSG,
22-
NO_STANZA_TOKENIZER_ERROR_MSG,
23-
can_load_spacy_tokenizer,
24-
can_load_stanza_tokenizer,
21+
Extras,
22+
raise_if_package_not_available,
2523
)
2624
from lighteval.utils.language import Language
2725

@@ -102,8 +100,8 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]:
102100
class SpaCyTokenizer(WordTokenizer):
103101
def __init__(self, spacy_language: str, config=None):
104102
super().__init__()
105-
if not can_load_spacy_tokenizer(spacy_language):
106-
raise ImportError(NO_SPACY_TOKENIZER_ERROR_MSG)
103+
raise_if_package_not_available(Extras.MULTILINGUAL, language=spacy_language)
104+
107105
self.spacy_language = spacy_language
108106
self.config = config
109107
self._tokenizer = None
@@ -140,8 +138,7 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]:
140138
class StanzaTokenizer(WordTokenizer):
141139
def __init__(self, stanza_language: str, **stanza_kwargs):
142140
super().__init__()
143-
if not can_load_stanza_tokenizer():
144-
raise ImportError(NO_STANZA_TOKENIZER_ERROR_MSG)
141+
raise_if_package_not_available("stanza")
145142
self.stanza_language = stanza_language
146143
self.stanza_kwargs = stanza_kwargs
147144
self._tokenizer = None

src/lighteval/metrics/utils/llm_as_judge.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from tqdm import tqdm
3535
from tqdm.asyncio import tqdm_asyncio
3636

37-
from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
37+
from lighteval.utils.imports import raise_if_package_not_available
3838
from lighteval.utils.utils import as_list
3939

4040

@@ -151,8 +151,7 @@ def __lazy_load_client(self): # noqa: C901
151151
# Both "openai" and "tgi" backends use the OpenAI-compatible API
152152
# They are handled separately to allow for backend-specific validation and setup
153153
case "openai" | "tgi":
154-
if not is_openai_available():
155-
raise RuntimeError("OpenAI backend is not available.")
154+
raise_if_package_not_available("openai")
156155
if self.client is None:
157156
from openai import OpenAI
158157

@@ -162,13 +161,11 @@ def __lazy_load_client(self): # noqa: C901
162161
return self.__call_api_parallel
163162

164163
case "litellm":
165-
if not is_litellm_available():
166-
raise RuntimeError("litellm is not available.")
164+
raise_if_package_not_available("litellm")
167165
return self.__call_litellm
168166

169167
case "vllm":
170-
if not is_vllm_available():
171-
raise RuntimeError("vllm is not available.")
168+
raise_if_package_not_available("vllm")
172169
if self.pipe is None:
173170
from vllm import LLM, SamplingParams
174171
from vllm.transformers_utils.tokenizer import get_tokenizer

src/lighteval/metrics/utils/math_comparison.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from sympy.core.function import UndefinedFunction
5252
from sympy.core.relational import Relational
5353

54-
from lighteval.utils.imports import requires_latex2sympy2_extended
54+
from lighteval.utils.imports import requires
5555
from lighteval.utils.timeout import timeout
5656

5757

@@ -308,7 +308,7 @@ def is_equation(expr: Basic | MatrixBase) -> bool:
308308
return False
309309

310310

311-
@requires_latex2sympy2_extended
311+
@requires("latex2sympy2_extended")
312312
def is_assignment_relation(expr: Basic | MatrixBase) -> bool:
313313
from latex2sympy2_extended.latex2sympy2 import is_expr_of_only_symbols
314314

src/lighteval/models/endpoints/litellm_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@
3232
from lighteval.tasks.prompt_manager import PromptManager
3333
from lighteval.tasks.requests import Doc, SamplingMethod
3434
from lighteval.utils.cache_management import SampleCache, cached
35-
from lighteval.utils.imports import is_litellm_available
35+
from lighteval.utils.imports import is_package_available, requires
3636

3737

3838
logger = logging.getLogger(__name__)
3939

40-
if is_litellm_available():
40+
if is_package_available("litellm"):
4141
import litellm
4242
from litellm import encode
4343
from litellm.caching.caching import Cache
@@ -110,6 +110,7 @@ class LiteLLMModelConfig(ModelConfig):
110110
concurrent_requests: int = 10
111111

112112

113+
@requires("litellm")
113114
class LiteLLMClient(LightevalModel):
114115
_DEFAULT_MAX_LENGTH: int = 4096
115116

0 commit comments

Comments
 (0)