diff --git a/docs/source/using-the-python-api.mdx b/docs/source/using-the-python-api.mdx index 1a21ebe4b..f09802272 100644 --- a/docs/source/using-the-python-api.mdx +++ b/docs/source/using-the-python-api.mdx @@ -12,9 +12,9 @@ import lighteval from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.models.vllm.vllm_model import VLLMModelConfig from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters -from lighteval.utils.imports import is_accelerate_available +from lighteval.utils.imports import is_package_available -if is_accelerate_available(): +if is_package_available("accelerate"): from datetime import timedelta from accelerate import Accelerator, InitProcessGroupKwargs accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))]) diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py index 108877601..aed32d2f1 100644 --- a/src/lighteval/logging/evaluation_tracker.py +++ b/src/lighteval/logging/evaluation_tracker.py @@ -43,13 +43,13 @@ TaskConfigLogger, VersionsLogger, ) -from lighteval.utils.imports import NO_TENSORBOARDX_WARN_MSG, is_nanotron_available, is_tensorboardX_available +from lighteval.utils.imports import is_package_available, not_installed_error_message from lighteval.utils.utils import obj_to_markdown logger = logging.getLogger(__name__) -if is_nanotron_available(): +if is_package_available("nanotron"): from nanotron.config import GeneralArgs # type: ignore try: @@ -659,11 +659,11 @@ def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901 def push_to_tensorboard( # noqa: C901 self, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail] ): - if not is_tensorboardX_available: - logger.warning(NO_TENSORBOARDX_WARN_MSG) + if not is_package_available("tensorboardX"): + logger.warning(not_installed_error_message("tensorboardX")) return - if not is_nanotron_available(): + if not is_package_available("nanotron"): logger.warning("You cannot push results to tensorboard without having nanotron installed. 
Skipping") return diff --git a/src/lighteval/logging/info_loggers.py b/src/lighteval/logging/info_loggers.py index 64019ecf8..4482fabb2 100644 --- a/src/lighteval/logging/info_loggers.py +++ b/src/lighteval/logging/info_loggers.py @@ -34,13 +34,13 @@ from lighteval.models.model_output import ModelResponse from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig from lighteval.tasks.requests import Doc -from lighteval.utils.imports import is_nanotron_available +from lighteval.utils.imports import is_package_available logger = logging.getLogger(__name__) -if is_nanotron_available(): +if is_package_available("nanotron"): pass diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py index 1131aea33..b844a74a4 100644 --- a/src/lighteval/main_nanotron.py +++ b/src/lighteval/main_nanotron.py @@ -32,11 +32,13 @@ reasoning_tags, remove_reasoning_tags, ) +from lighteval.utils.imports import requires SEED = 1234 +@requires("nanotron") def nanotron( checkpoint_config_path: Annotated[ str, Option(help="Path to the nanotron checkpoint YAML or python config file, potentially on s3.") @@ -45,12 +47,9 @@ def nanotron( remove_reasoning_tags: remove_reasoning_tags.type = remove_reasoning_tags.default, reasoning_tags: reasoning_tags.type = reasoning_tags.default, ): - """Evaluate models using nanotron as backend.""" - from lighteval.utils.imports import NO_NANOTRON_ERROR_MSG, is_nanotron_available - - if not is_nanotron_available(): - raise ImportError(NO_NANOTRON_ERROR_MSG) - + """ + Evaluate models using nanotron as backend. + """ from nanotron.config import GeneralArgs, ModelArgs, TokenizerArgs, get_config_from_dict, get_config_from_file from lighteval.logging.evaluation_tracker import EvaluationTracker diff --git a/src/lighteval/metrics/imports/data_stats_metric.py b/src/lighteval/metrics/imports/data_stats_metric.py index bc2abf175..f5455eb99 100644 --- a/src/lighteval/metrics/imports/data_stats_metric.py +++ b/src/lighteval/metrics/imports/data_stats_metric.py @@ -29,7 +29,7 @@ from multiprocessing import Pool from lighteval.metrics.imports.data_stats_utils import Fragments -from lighteval.utils.imports import NO_SPACY_ERROR_MSG, is_spacy_available +from lighteval.utils.imports import raise_if_package_not_available logger = logging.getLogger(__name__) @@ -70,8 +70,7 @@ def __init__(self, n_gram=3, n_workers=24, case=False, tokenize=True): tokenize (bool): whether to tokenize the input; otherwise assumes that the input is a string of space-separated tokens. 
""" - if not is_spacy_available(): - raise ImportError(NO_SPACY_ERROR_MSG) + raise_if_package_not_available("spacy") import spacy self.n_gram = n_gram diff --git a/src/lighteval/metrics/utils/extractive_match_utils.py b/src/lighteval/metrics/utils/extractive_match_utils.py index d16145aea..cce2b1793 100644 --- a/src/lighteval/metrics/utils/extractive_match_utils.py +++ b/src/lighteval/metrics/utils/extractive_match_utils.py @@ -34,12 +34,12 @@ from lighteval.tasks.requests import Doc from lighteval.tasks.templates.utils.formulation import ChoicePrefix, get_prefix from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS -from lighteval.utils.imports import requires_latex2sympy2_extended +from lighteval.utils.imports import requires from lighteval.utils.language import Language from lighteval.utils.timeout import timeout -@requires_latex2sympy2_extended +@requires("latex2sympy2_extended") def latex_normalization_config_default_factory(): from latex2sympy2_extended.latex2sympy2 import NormalizationConfig @@ -373,7 +373,7 @@ def get_target_type_order(target_type: ExtractionTarget) -> int: # Small cache, to catche repeated calls invalid parsing @lru_cache(maxsize=20) -@requires_latex2sympy2_extended +@requires("latex2sympy2_extended") def parse_latex_with_timeout(latex: str, timeout_seconds: int): from latex2sympy2_extended.latex2sympy2 import latex2sympy @@ -428,7 +428,7 @@ def convert_to_pct(number: Number): return sympy.Mul(number, sympy.Rational(1, 100), evaluate=False) -@requires_latex2sympy2_extended +@requires("latex2sympy2_extended") @lru_cache(maxsize=20) def extract_latex( match: re.Match, latex_config: LatexExtractionConfig, timeout_seconds: int diff --git a/src/lighteval/metrics/utils/linguistic_tokenizers.py b/src/lighteval/metrics/utils/linguistic_tokenizers.py index e0dd9ef1a..137ac3417 100644 --- a/src/lighteval/metrics/utils/linguistic_tokenizers.py +++ b/src/lighteval/metrics/utils/linguistic_tokenizers.py @@ -18,10 +18,8 @@ from typing import Callable, Iterator from lighteval.utils.imports import ( - NO_SPACY_TOKENIZER_ERROR_MSG, - NO_STANZA_TOKENIZER_ERROR_MSG, - can_load_spacy_tokenizer, - can_load_stanza_tokenizer, + Extras, + raise_if_package_not_available, ) from lighteval.utils.language import Language @@ -102,8 +100,8 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]: class SpaCyTokenizer(WordTokenizer): def __init__(self, spacy_language: str, config=None): super().__init__() - if not can_load_spacy_tokenizer(spacy_language): - raise ImportError(NO_SPACY_TOKENIZER_ERROR_MSG) + raise_if_package_not_available(Extras.MULTILINGUAL, language=spacy_language) + self.spacy_language = spacy_language self.config = config self._tokenizer = None @@ -140,8 +138,7 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]: class StanzaTokenizer(WordTokenizer): def __init__(self, stanza_language: str, **stanza_kwargs): super().__init__() - if not can_load_stanza_tokenizer(): - raise ImportError(NO_STANZA_TOKENIZER_ERROR_MSG) + raise_if_package_not_available("stanza") self.stanza_language = stanza_language self.stanza_kwargs = stanza_kwargs self._tokenizer = None diff --git a/src/lighteval/metrics/utils/llm_as_judge.py b/src/lighteval/metrics/utils/llm_as_judge.py index dcf0a5a88..f2a48ea27 100644 --- a/src/lighteval/metrics/utils/llm_as_judge.py +++ b/src/lighteval/metrics/utils/llm_as_judge.py @@ -33,7 +33,7 @@ from tqdm import tqdm from tqdm.asyncio import tqdm_asyncio -from lighteval.utils.imports import is_litellm_available, 
is_openai_available, is_vllm_available +from lighteval.utils.imports import raise_if_package_not_available from lighteval.utils.utils import as_list @@ -131,8 +131,7 @@ def __lazy_load_client(self): # noqa: C901 # Both "openai" and "tgi" backends use the OpenAI-compatible API # They are handled separately to allow for backend-specific validation and setup case "openai" | "tgi": - if not is_openai_available(): - raise RuntimeError("OpenAI backend is not available.") + raise_if_package_not_available("openai") if self.client is None: from openai import OpenAI @@ -142,13 +141,11 @@ def __lazy_load_client(self): # noqa: C901 return self.__call_api_parallel case "litellm": - if not is_litellm_available(): - raise RuntimeError("litellm is not available.") + raise_if_package_not_available("litellm") return self.__call_litellm case "vllm": - if not is_vllm_available(): - raise RuntimeError("vllm is not available.") + raise_if_package_not_available("vllm") if self.pipe is None: from vllm import LLM, SamplingParams from vllm.transformers_utils.tokenizer import get_tokenizer diff --git a/src/lighteval/metrics/utils/math_comparison.py b/src/lighteval/metrics/utils/math_comparison.py index 2650ee335..2329acfe0 100644 --- a/src/lighteval/metrics/utils/math_comparison.py +++ b/src/lighteval/metrics/utils/math_comparison.py @@ -51,7 +51,7 @@ from sympy.core.function import UndefinedFunction from sympy.core.relational import Relational -from lighteval.utils.imports import requires_latex2sympy2_extended +from lighteval.utils.imports import requires from lighteval.utils.timeout import timeout @@ -308,7 +308,7 @@ def is_equation(expr: Basic | MatrixBase) -> bool: return False -@requires_latex2sympy2_extended +@requires("latex2sympy2_extended") def is_assignment_relation(expr: Basic | MatrixBase) -> bool: from latex2sympy2_extended.latex2sympy2 import is_expr_of_only_symbols diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py index 544123e00..0b436c691 100644 --- a/src/lighteval/models/endpoints/litellm_model.py +++ b/src/lighteval/models/endpoints/litellm_model.py @@ -32,12 +32,12 @@ from lighteval.tasks.prompt_manager import PromptManager from lighteval.tasks.requests import Doc from lighteval.utils.cache_management import SampleCache, cached -from lighteval.utils.imports import is_litellm_available +from lighteval.utils.imports import is_package_available, requires logger = logging.getLogger(__name__) -if is_litellm_available(): +if is_package_available("litellm"): import litellm from litellm import encode from litellm.caching.caching import Cache @@ -110,6 +110,7 @@ class LiteLLMModelConfig(ModelConfig): concurrent_requests: int = 10 +@requires("litellm") class LiteLLMClient(LightevalModel): _DEFAULT_MAX_LENGTH: int = 4096 diff --git a/src/lighteval/models/endpoints/tgi_model.py b/src/lighteval/models/endpoints/tgi_model.py index 8130cba88..4b4847fe9 100644 --- a/src/lighteval/models/endpoints/tgi_model.py +++ b/src/lighteval/models/endpoints/tgi_model.py @@ -32,10 +32,10 @@ from lighteval.models.endpoints.endpoint_model import InferenceEndpointModel from lighteval.tasks.prompt_manager import PromptManager from lighteval.utils.cache_management import SampleCache -from lighteval.utils.imports import NO_TGI_ERROR_MSG, is_tgi_available +from lighteval.utils.imports import is_package_available, requires -if is_tgi_available(): +if is_package_available("tgi"): from text_generation import AsyncClient else: from unittest.mock import Mock @@ -99,12 +99,11 
@@ class TGIModelConfig(ModelConfig): # inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite # the client functions, since they use a different client. +@requires("tgi") class ModelClient(InferenceEndpointModel): _DEFAULT_MAX_LENGTH: int = 4096 def __init__(self, config: TGIModelConfig) -> None: - if not is_tgi_available(): - raise ImportError(NO_TGI_ERROR_MSG) headers = ( {} if config.inference_server_auth is None else {"Authorization": f"Bearer {config.inference_server_auth}"} ) diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index ccae20a5f..46129960d 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -43,16 +43,7 @@ from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig from lighteval.models.transformers.vlm_transformers_model import VLMTransformersModel, VLMTransformersModelConfig from lighteval.models.vllm.vllm_model import AsyncVLLMModel, VLLMModel, VLLMModelConfig -from lighteval.utils.imports import ( - NO_LITELLM_ERROR_MSG, - NO_SGLANG_ERROR_MSG, - NO_TGI_ERROR_MSG, - NO_VLLM_ERROR_MSG, - is_litellm_available, - is_sglang_available, - is_tgi_available, - is_vllm_available, -) +from lighteval.utils.imports import raise_if_package_not_available, requires logger = logging.getLogger(__name__) @@ -101,19 +92,15 @@ def load_model( # noqa: C901 return load_inference_providers_model(config=config) +@requires("tgi") def load_model_with_tgi(config: TGIModelConfig): - if not is_tgi_available(): - raise ImportError(NO_TGI_ERROR_MSG) - logger.info(f"Load model from inference server: {config.inference_server_address}") model = ModelClient(config=config) return model +@requires("litellm") def load_litellm_model(config: LiteLLMModelConfig): - if not is_litellm_available(): - raise ImportError(NO_LITELLM_ERROR_MSG) - model = LiteLLMClient(config) return model @@ -163,8 +150,7 @@ def load_model_with_accelerate_or_default( elif isinstance(config, DeltaModelConfig): model = DeltaModel(config=config) elif isinstance(config, VLLMModelConfig): - if not is_vllm_available(): - raise ImportError(NO_VLLM_ERROR_MSG) + raise_if_package_not_available("vllm") if config.is_async: model = AsyncVLLMModel(config=config) else: @@ -185,8 +171,6 @@ def load_inference_providers_model(config: InferenceProvidersModelConfig): return InferenceProvidersClient(config=config) +@requires("sglang") def load_sglang_model(config: SGLangModelConfig): - if not is_sglang_available(): - raise ImportError(NO_SGLANG_ERROR_MSG) - return SGLangModel(config=config) diff --git a/src/lighteval/models/nanotron/nanotron_model.py b/src/lighteval/models/nanotron/nanotron_model.py index 686111e04..0da9e7bdb 100644 --- a/src/lighteval/models/nanotron/nanotron_model.py +++ b/src/lighteval/models/nanotron/nanotron_model.py @@ -50,7 +50,7 @@ Doc, ) from lighteval.utils.cache_management import SampleCache, cached -from lighteval.utils.imports import is_nanotron_available +from lighteval.utils.imports import is_package_available from lighteval.utils.parallelism import find_executable_batch_size from lighteval.utils.utils import as_list @@ -62,7 +62,7 @@ TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding] -if is_nanotron_available(): +if is_package_available("nanotron"): from nanotron import distributed as dist from nanotron import logging from nanotron.config import GeneralArgs, ModelArgs, TokenizerArgs diff --git 
a/src/lighteval/models/sglang/sglang_model.py b/src/lighteval/models/sglang/sglang_model.py index 3c39315a8..771e88565 100644 --- a/src/lighteval/models/sglang/sglang_model.py +++ b/src/lighteval/models/sglang/sglang_model.py @@ -35,12 +35,12 @@ from lighteval.tasks.prompt_manager import PromptManager from lighteval.tasks.requests import Doc from lighteval.utils.cache_management import SampleCache, cached -from lighteval.utils.imports import is_sglang_available +from lighteval.utils.imports import is_package_available, requires logger = logging.getLogger(__name__) -if is_sglang_available(): +if is_package_available("sglang"): from sglang import Engine from sglang.srt.hf_transformers_utils import get_tokenizer @@ -138,6 +138,7 @@ class SGLangModelConfig(ModelConfig): override_chat_template: bool = None +@requires("sglang") class SGLangModel(LightevalModel): def __init__( self, diff --git a/src/lighteval/models/transformers/adapter_model.py b/src/lighteval/models/transformers/adapter_model.py index a868ad20f..52f339664 100644 --- a/src/lighteval/models/transformers/adapter_model.py +++ b/src/lighteval/models/transformers/adapter_model.py @@ -30,15 +30,16 @@ from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig from lighteval.models.utils import _get_dtype -from lighteval.utils.imports import NO_PEFT_ERROR_MSG, is_peft_available +from lighteval.utils.imports import is_package_available, requires logger = logging.getLogger(__name__) -if is_peft_available(): +if is_package_available("peft"): from peft import PeftModel +@requires("peft") class AdapterModelConfig(TransformersModelConfig): """Configuration class for PEFT (Parameter-Efficient Fine-Tuning) adapter models. @@ -58,10 +59,6 @@ class AdapterModelConfig(TransformersModelConfig): base_model: str - def model_post_init(self, __context): - if not is_peft_available(): - raise ImportError(NO_PEFT_ERROR_MSG) - class AdapterModel(TransformersModel): def _create_auto_model(self) -> transformers.PreTrainedModel: diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py index 3d0d7c12b..94ac60426 100644 --- a/src/lighteval/models/transformers/transformers_model.py +++ b/src/lighteval/models/transformers/transformers_model.py @@ -55,7 +55,7 @@ from lighteval.tasks.requests import Doc from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import ( - is_accelerate_available, + is_package_available, ) from lighteval.utils.parallelism import find_executable_batch_size @@ -227,7 +227,7 @@ def __init__( self.model_name = _simplify_name(config.model_name) - if is_accelerate_available(): + if is_package_available("accelerate"): model_size, _ = calculate_maximum_sizes(self.model) model_size = convert_bytes(model_size) else: @@ -290,7 +290,7 @@ def from_model( else: self._device = self.config.device - if is_accelerate_available(): + if is_package_available("accelerate"): model_size, _ = calculate_maximum_sizes(self.model) model_size = convert_bytes(model_size) else: @@ -331,7 +331,7 @@ def disable_tqdm(self) -> bool: def init_model_parallel(self, model_parallel: bool | None = None) -> Tuple[bool, Optional[dict], Optional[str]]: """Compute all the parameters related to model_parallel""" - if not is_accelerate_available(): + if not is_package_available("accelerate"): return False, None, None self.num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) diff --git 
a/src/lighteval/models/transformers/vlm_transformers_model.py b/src/lighteval/models/transformers/vlm_transformers_model.py index 61f9d58ab..6357559fd 100644 --- a/src/lighteval/models/transformers/vlm_transformers_model.py +++ b/src/lighteval/models/transformers/vlm_transformers_model.py @@ -47,7 +47,7 @@ from lighteval.tasks.requests import Doc from lighteval.utils.cache_management import SampleCache, cached from lighteval.utils.imports import ( - is_accelerate_available, + is_package_available, ) @@ -210,7 +210,7 @@ def disable_tqdm(self) -> bool: # Copied from ./transformers_model.py def init_model_parallel(self, model_parallel: bool | None = None) -> Tuple[bool, Optional[dict], Optional[str]]: """Compute all the parameters related to model_parallel""" - if not is_accelerate_available(): + if not is_package_available("accelerate"): return False, None, None self.num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index eac77e640..63dae247d 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -38,13 +38,13 @@ from lighteval.tasks.prompt_manager import PromptManager from lighteval.tasks.requests import Doc from lighteval.utils.cache_management import SampleCache, cached -from lighteval.utils.imports import is_vllm_available +from lighteval.utils.imports import is_package_available, requires logger = logging.getLogger(__name__) -if is_vllm_available(): +if is_package_available("vllm"): import ray from more_itertools import distribute from vllm import LLM, RequestOutput, SamplingParams @@ -176,6 +176,7 @@ class VLLMModelConfig(ModelConfig): override_chat_template: bool = None +@requires("vllm") class VLLMModel(LightevalModel): def __init__( self, @@ -526,6 +527,7 @@ def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: raise NotImplementedError() +@requires("vllm") class AsyncVLLMModel(VLLMModel): """VLLM models which deploy async natively (no ray). 
Supports DP and PP/TP but not batch size > 1""" diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py index 71b1efd4a..2bd5399ed 100644 --- a/src/lighteval/pipeline.py +++ b/src/lighteval/pipeline.py @@ -42,31 +42,19 @@ from lighteval.tasks.lighteval_task import LightevalTask from lighteval.tasks.registry import Registry from lighteval.tasks.requests import SamplingMethod -from lighteval.utils.imports import ( - NO_ACCELERATE_ERROR_MSG, - NO_NANOTRON_ERROR_MSG, - NO_OPENAI_ERROR_MSG, - NO_SGLANG_ERROR_MSG, - NO_TGI_ERROR_MSG, - NO_VLLM_ERROR_MSG, - is_accelerate_available, - is_nanotron_available, - is_openai_available, - is_sglang_available, - is_tgi_available, - is_vllm_available, -) +from lighteval.utils.imports import is_package_available, raise_if_package_not_available from lighteval.utils.parallelism import test_all_gather from lighteval.utils.utils import make_results_table, remove_reasoning_tags -if is_accelerate_available(): +if is_package_available("accelerate"): from accelerate import Accelerator, InitProcessGroupKwargs else: from unittest.mock import Mock Accelerator = InitProcessGroupKwargs = Mock() -if is_nanotron_available(): + +if is_package_available("nanotron"): from nanotron import distributed as dist from nanotron.parallel.context import ParallelContext @@ -110,23 +98,17 @@ class PipelineParameters: def __post_init__(self): # noqa C901 # Import testing if self.launcher_type == ParallelismManager.ACCELERATE: - if not is_accelerate_available(): - raise ImportError(NO_ACCELERATE_ERROR_MSG) + raise_if_package_not_available("accelerate") elif self.launcher_type == ParallelismManager.VLLM: - if not is_vllm_available(): - raise ImportError(NO_VLLM_ERROR_MSG) + raise_if_package_not_available("vllm") elif self.launcher_type == ParallelismManager.SGLANG: - if not is_sglang_available(): - raise ImportError(NO_SGLANG_ERROR_MSG) + raise_if_package_not_available("sglang") elif self.launcher_type == ParallelismManager.TGI: - if not is_tgi_available(): - raise ImportError(NO_TGI_ERROR_MSG) + raise_if_package_not_available("tgi") elif self.launcher_type == ParallelismManager.NANOTRON: - if not is_nanotron_available(): - raise ImportError(NO_NANOTRON_ERROR_MSG) + raise_if_package_not_available("nanotron") elif self.launcher_type == ParallelismManager.OPENAI: - if not is_openai_available(): - raise ImportError(NO_OPENAI_ERROR_MSG) + raise_if_package_not_available("openai") # Convert reasoning tags to list if needed if not isinstance(self.reasoning_tags, list): @@ -187,12 +169,12 @@ def __init__( def _init_parallelism_manager(self): accelerator, parallel_context = None, None if self.launcher_type == ParallelismManager.ACCELERATE: - if not is_accelerate_available(): + if not is_package_available("accelerate"): raise ValueError("You are trying to launch an accelerate model, but accelerate is not installed") accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))]) test_all_gather(accelerator=accelerator) elif self.launcher_type == ParallelismManager.NANOTRON: - if not is_nanotron_available(): + if not is_package_available("nanotron"): raise ValueError("You are trying to launch a nanotron model, but nanotron is not installed") dist.initialize_torch_distributed() parallel_context = ParallelContext( diff --git a/src/lighteval/tasks/extended/__init__.py b/src/lighteval/tasks/extended/__init__.py index 39963eac1..c81399a9e 100644 --- a/src/lighteval/tasks/extended/__init__.py +++ b/src/lighteval/tasks/extended/__init__.py @@ -20,19 +20,14 @@ 
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-from lighteval.utils.imports import can_load_extended_tasks
+import lighteval.tasks.extended.hle.main as hle
+import lighteval.tasks.extended.ifeval.main as ifeval
+import lighteval.tasks.extended.lcb.main as lcb
+import lighteval.tasks.extended.mix_eval.main as mix_eval
+import lighteval.tasks.extended.mt_bench.main as mt_bench
+import lighteval.tasks.extended.olympiade_bench.main as olympiad_bench
+import lighteval.tasks.extended.tiny_benchmarks.main as tiny_benchmarks
 
-if can_load_extended_tasks():
-    import lighteval.tasks.extended.hle.main as hle
-    import lighteval.tasks.extended.ifeval.main as ifeval
-    import lighteval.tasks.extended.lcb.main as lcb
-    import lighteval.tasks.extended.mix_eval.main as mix_eval
-    import lighteval.tasks.extended.mt_bench.main as mt_bench
-    import lighteval.tasks.extended.olympiade_bench.main as olympiad_bench
-    import lighteval.tasks.extended.tiny_benchmarks.main as tiny_benchmarks
 
-    AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench, mix_eval, olympiad_bench, hle, lcb]
-
-else:
-    AVAILABLE_EXTENDED_TASKS_MODULES = []
+AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench, mix_eval, olympiad_bench, hle, lcb]
diff --git a/src/lighteval/tasks/extended/ifeval/instructions.py b/src/lighteval/tasks/extended/ifeval/instructions.py
index 4022e640f..806125485 100644
--- a/src/lighteval/tasks/extended/ifeval/instructions.py
+++ b/src/lighteval/tasks/extended/ifeval/instructions.py
@@ -21,7 +21,11 @@
 import re
 import string
 
-import langdetect
+from lighteval.utils.imports import is_package_available
+
+
+if is_package_available("langdetect"):
+    import langdetect
 
 import lighteval.tasks.extended.ifeval.instructions_utils as instructions_util
diff --git a/src/lighteval/tasks/extended/ifeval/main.py b/src/lighteval/tasks/extended/ifeval/main.py
index 1f63f91f1..bde8ae709 100644
--- a/src/lighteval/tasks/extended/ifeval/main.py
+++ b/src/lighteval/tasks/extended/ifeval/main.py
@@ -31,9 +31,11 @@
 from lighteval.models.model_output import ModelResponse
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc, SamplingMethod
+from lighteval.utils.imports import requires
 
 
 # Very specific task where there are no precise outputs but instead we test if the format obeys rules
+@requires("langdetect")
 def ifeval_prompt(line, task_name: str = ""):
     return Doc(
         task_name=task_name,
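The `@requires(...)` decorator used above (defined in `src/lighteval/utils/imports.py` further down in this patch) moves the dependency failure from import time to call time: the module imports cleanly without the optional package, and the `ImportError` with its install hint only fires when a decorated callable is actually used. A minimal sketch of that behavior, using a hypothetical `detect_language` helper and assuming an environment where `langdetect` is not installed:

```python
from lighteval.utils.imports import requires


@requires("langdetect")
def detect_language(text: str) -> str:
    import langdetect  # only reached when the package is installed

    return langdetect.detect(text)


detect_language("bonjour")  # raises ImportError with an install hint if langdetect is missing
```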
diff --git a/src/lighteval/tasks/registry.py b/src/lighteval/tasks/registry.py
index 01125c778..cd3a76920 100644
--- a/src/lighteval/tasks/registry.py
+++ b/src/lighteval/tasks/registry.py
@@ -36,12 +36,7 @@
 import lighteval.tasks.default_tasks as default_tasks
 from lighteval.tasks.extended import AVAILABLE_EXTENDED_TASKS_MODULES
 from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig
-from lighteval.utils.imports import (
-    CANNOT_USE_EXTENDED_TASKS_MSG,
-    CANNOT_USE_MULTILINGUAL_TASKS_MSG,
-    can_load_extended_tasks,
-    can_load_multilingual_tasks,
-)
+from lighteval.utils.imports import Extras, raise_if_package_not_available
 
 
 # Import community tasks
@@ -122,7 +117,6 @@ def __init__(
         self,
         tasks: str | Path | None = None,
         custom_tasks: str | Path | ModuleType | None = None,
         load_community: bool = False,
-        load_extended: bool = False,
         load_multilingual: bool = False,
     ):
         """
@@ -158,17 +152,19 @@
             logger.warning(
                 "You passed no task name. This should only occur if you are using the CLI to inspect tasks."
             )
-            self.tasks_list = []
+            tasks = []
         else:
-            self.tasks_list = self._get_full_task_list_from_input_string(tasks)
+            tasks = self._get_full_task_list_from_input_string(tasks)
 
         # These parameters are dynamically set by the task names provided, thanks to `activate_suites_to_load`,
         # except in the `tasks` CLI command to display the full list
         self._load_community = load_community
-        self._load_extended = load_extended
         self._load_multilingual = load_multilingual
-        self._activate_loading_of_optional_suite()  # we dynamically set the loading parameters
-        # We load all task to
+        # Sanitize tasks by inferring suites/few shots when not specified
+        self.tasks_list = self._sanitize_tasks_list(tasks)
+
+        # Loads all the available tasks
+        self._activate_loading_of_optional_suite()  # we dynamically set the loading parameters
         self._task_registry = self._load_full_registry()
         self.task_to_configs = self._update_task_configs()
@@ -219,13 +215,8 @@
                     f"Suite {suite_name} unknown. This is not normal, unless you are testing adding new evaluations."
                 )
 
-        if "extended" in suites:
-            if not can_load_extended_tasks():
-                raise ImportError(CANNOT_USE_EXTENDED_TASKS_MSG)
-            self._load_extended = True
         if "multilingual" in suites:
-            if not can_load_multilingual_tasks():
-                raise ImportError(CANNOT_USE_MULTILINGUAL_TASKS_MSG)
+            raise_if_package_not_available(Extras.MULTILINGUAL)
             self._load_multilingual = True
         if "community" in suites:
             self._load_community = True
@@ -248,11 +239,8 @@
             custom_tasks_module.append(Registry.create_custom_tasks_module(custom_tasks=self._custom_tasks))
 
         # Need to load extended tasks
-        if self._load_extended:
-            for extended_task_module in AVAILABLE_EXTENDED_TASKS_MODULES:
-                custom_tasks_module.append(extended_task_module)
-        else:
-            logger.warning(CANNOT_USE_EXTENDED_TASKS_MSG)
+        for extended_task_module in AVAILABLE_EXTENDED_TASKS_MODULES:
+            custom_tasks_module.append(extended_task_module)
 
         # Need to load community tasks
         if self._load_community:
@@ -285,6 +273,65 @@
 
         return {**default_tasks_registry, **custom_tasks_registry}
 
+    @lru_cache()
+    def _get_suite_from_task(self):
+        registry = self._load_full_registry()
+        task_name_to_suite = {}
+        for task in registry.keys():
+            suite, task_name = task.rsplit("|", 1)
+            if task_name not in task_name_to_suite:
+                task_name_to_suite[task_name] = [suite]
+            else:
+                task_name_to_suite[task_name].append(suite)
+
+        return task_name_to_suite
+
+    def _infer_suite_name_from_task(self, taskname):
+        suite_from_task = self._get_suite_from_task()
+
+        if taskname not in suite_from_task:
+            raise ValueError(f"Requested task {taskname} is not available")
+
+        if len(suite_from_task[taskname]) > 1:
+            raise ValueError(f"More than one suite available for task {taskname}: {suite_from_task[taskname]}")
+
+        return suite_from_task[taskname][0]
+
+    def _sanitize_tasks_list(self, tasks):
+        tasks_list = []
+        for task in tasks:
+            try:
+                if task.count("|") == 3:
+                    logger.warning(
+                        "Deprecation warning: You provided 4 arguments in your task name, but we no longer support the `truncate_fewshot` option. We will ignore the parameter for now, but it will fail in a couple of versions, so you should change your task name to `suite|task` or `suite|task|num_fewshot`."
+                    )
+                    suite_name, task_name, few_shot, _ = tuple(task.split("|"))
+                elif task.count("|") == 2:
+                    suite_name, task_name, few_shot = tuple(task.split("|"))
+                elif task.count("|") == 1:
+                    arg0, arg1 = tuple(task.split("|"))
+
+                    if arg1.isdigit():
+                        suite_name = self._infer_suite_name_from_task(arg0)
+                        task_name, few_shot = arg0, arg1
+                    else:
+                        suite_name, task_name = arg0, arg1
+                        few_shot = "0"
+                elif task.count("|") == 0:
+                    suite_name = self._infer_suite_name_from_task(task)
+                    task_name = task
+                    few_shot = "0"
+                else:
+                    raise ValueError(
+                        f"Cannot get task info from {task}. The correct format is:\n- task\n- suite|task\n- suite|task|few_shot"
+                    )
+            except ValueError as e:
+                raise ValueError(f"Cannot get task info from {task}. The correct format is `suite|task|few_shot`.") from e
+
+            tasks_list.append("|".join([suite_name, task_name, few_shot]))
+        return tasks_list
+
     def _update_task_configs(self) -> dict[str, LightevalTaskConfig]:  # noqa: C901
         """
         Updates each config depending on the input tasks (we replace all provided params, like few shot number, sampling params, etc)
@@ -297,17 +344,33 @@
             try:
                 if task.count("|") == 3:
                     logger.warning(
-                        "Deprecation warning: You provided 4 arguments in your task name, but we no longer support the `truncate_fewshot` option. We will ignore the parameter for now, but it will fail in a couple of versions, so you should change your task name to `suite|task|num_fewshot`."
+                        "Deprecation warning: You provided 4 arguments in your task name, but we no longer support the `truncate_fewshot` option. We will ignore the parameter for now, but it will fail in a couple of versions, so you should change your task name to `suite|task` or `suite|task|num_fewshot`."
                     )
                     suite_name, task_name, few_shot, _ = tuple(task.split("|"))
-                else:
+                elif task.count("|") == 2:
                     suite_name, task_name, few_shot = tuple(task.split("|"))
+                elif task.count("|") == 1:
+                    suite_name, task_name = tuple(task.split("|"))
+                    few_shot = 0
+                elif task.count("|") == 0:
+                    suite_name = None
+                    task_name = task
+                    few_shot = 0
+                else:
+                    raise ValueError(
+                        f"Cannot get task info from {task}. The correct format is:\n- task\n- suite|task\n- suite|task|few_shot"
+                    )
+
                 if "@" in task_name:
                     split_task_name = task_name.split("@")
                     task_name, metric_params = split_task_name[0], split_task_name[1:]
                     # We convert k:v to {"k": "v"}, then to correct type
                     metric_params_dict = dict(item.split("=") for item in metric_params if item)
                     metric_params_dict = {k: ast.literal_eval(v) for k, v in metric_params_dict.items()}
+
+                if suite_name is None:
+                    suite_name = self._infer_suite_name_from_task(task)
+
                 few_shot = int(few_shot)
             except ValueError:
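With `_sanitize_tasks_list` in place, the registry accepts progressively shorter task specifications and fills in the missing pieces itself. An illustrative sketch of the accepted spellings, mirroring the new tests further down and assuming `storycloze:2016` exists only in the `lighteval` suite:

```python
from lighteval.tasks.registry import Registry

# All four resolve to "lighteval|storycloze:2016|0":
Registry(tasks="lighteval|storycloze:2016|0")  # fully qualified: suite|task|num_fewshot
Registry(tasks="lighteval|storycloze:2016")    # few-shot count defaults to 0
Registry(tasks="storycloze:2016|0")            # suite inferred from the task name
Registry(tasks="storycloze:2016")              # suite and few-shot count both inferred

# If a task name exists in more than one suite, inference refuses to guess:
# _infer_suite_name_from_task raises a ValueError listing the candidate suites.
```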
diff --git a/src/lighteval/utils/imports.py b/src/lighteval/utils/imports.py
index 2534cb52a..f9ac5598a 100644
--- a/src/lighteval/utils/imports.py
+++ b/src/lighteval/utils/imports.py
@@ -11,117 +11,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import enum
 import importlib
+from functools import lru_cache
 
 
-def is_accelerate_available() -> bool:
-    return importlib.util.find_spec("accelerate") is not None
-
-
-NO_ACCELERATE_ERROR_MSG = "You requested the use of accelerate for this evaluation, but it is not available in your current environement. Please install it using pip."
-
-
-def is_tgi_available() -> bool:
-    return importlib.util.find_spec("text_generation") is not None
-
-
-NO_TGI_ERROR_MSG = "You are trying to start a text generation inference endpoint, but text-generation is not present in your local environement. Please install it using pip."
-
-
-def is_nanotron_available() -> bool:
-    return importlib.util.find_spec("nanotron") is not None
-
-
-NO_NANOTRON_ERROR_MSG = "You requested the use of nanotron for this evaluation, but it is not available in your current environement. Please install it using pip."
-
-
-def is_optimum_available() -> bool:
-    return importlib.util.find_spec("optimum") is not None
-
-
-def is_bnb_available() -> bool:
-    return importlib.util.find_spec("bitsandbytes") is not None
-
-
-NO_BNB_ERROR_MSG = "You are trying to load a model quantized with `bitsandbytes`, which is not available in your local environement. Please install it using pip."
-
-
-def is_autogptq_available() -> bool:
-    return importlib.util.find_spec("auto_gptq") is not None
-
-
-NO_AUTOGPTQ_ERROR_MSG = "You are trying to load a model quantized with `auto-gptq`, which is not available in your local environement. Please install it using pip."
-
-
-def is_peft_available() -> bool:
-    return importlib.util.find_spec("peft") is not None
-
-
-NO_PEFT_ERROR_MSG = "You are trying to use adapter weights models, for which you need `peft`, which is not available in your environment. Please install it using pip."
-
-
-def is_tensorboardX_available() -> bool:
-    return importlib.util.find_spec("tensorboardX") is not None
-
-
-NO_TENSORBOARDX_WARN_MSG = (
-    "You are trying to log using tensorboardX, which is not installed. Please install it using pip. Skipping."
-)
-
-
-def is_openai_available() -> bool:
-    return importlib.util.find_spec("openai") is not None
-
+class Extras(enum.Enum):
+    MULTILINGUAL = "multilingual"
+    EXTENDED = "extended"
 
-NO_OPENAI_ERROR_MSG = "You are trying to use an Open AI LLM as a judge, for which you need `openai`, which is not available in your environment. Please install it using pip."
 
+@lru_cache()
+def is_package_available(package_name: str):
+    if package_name == Extras.MULTILINGUAL:
+        return all(importlib.util.find_spec(package) is not None for package in ["stanza", "spacy", "langcodes"])
+    if package_name == Extras.EXTENDED:
+        return all(importlib.util.find_spec(package) is not None for package in ["spacy"])
+    # Callers pass "tgi" as a shorthand, but the TGI client ships as the `text_generation` module
+    module_name = {"tgi": "text_generation"}.get(package_name, package_name)
+    return importlib.util.find_spec(module_name) is not None
 
-def is_litellm_available() -> bool:
-    return importlib.util.find_spec("litellm") is not None
-
-NO_LITELLM_ERROR_MSG = "You are trying to use a LiteLLM model, for which you need `litellm`, which is not available in your environment. Please install it using pip."
-
-
-def is_vllm_available() -> bool:
-    return importlib.util.find_spec("vllm") is not None and importlib.util.find_spec("ray") is not None
-
-
-NO_VLLM_ERROR_MSG = "You are trying to use an VLLM model, for which you need `vllm` and `ray`, which are not available in your environment. Please install them using pip, `pip install vllm ray`."
-
-
-def is_sglang_available() -> bool:
-    return importlib.util.find_spec("sglang") is not None and importlib.util.find_spec("flashinfer") is not None
-
-
-NO_SGLANG_ERROR_MSG = "You are trying to use an sglang model, for which you need `sglang` and `flashinfer`, which are not available in your environment. Please install them using pip, `pip install vllm ray`."
-
-
-def can_load_extended_tasks() -> bool:
-    imports = []
-    for package in ["langdetect", "openai"]:
-        imports.append(importlib.util.find_spec(package))
-
-    return all(cur_import is not None for cur_import in imports)
-
-
-CANNOT_USE_EXTENDED_TASKS_MSG = "If you want to use extended_tasks, make sure you installed their dependencies using `pip install -e .[extended_tasks]`."
-
-
-def can_load_multilingual_tasks() -> bool:
-    try:
-        import lighteval.tasks.multilingual.tasks  # noqa: F401
-
-        return True
-    except ImportError:
-        return False
-
-
-CANNOT_USE_MULTILINGUAL_TASKS_MSG = "If you want to use multilingual tasks, make sure you installed their dependencies using `pip install -e .[multilingual]`."
-
-
-def can_load_spacy_tokenizer(language: str) -> bool:
+@lru_cache()
+def is_multilingual_package_available(language: str):
     imports = []
     packages = ["spacy", "stanza"]
     if language == "vi":
@@ -131,38 +42,41 @@
     for package in packages:
         imports.append(importlib.util.find_spec(package))
-    return all(cur_import is not None for cur_import in imports)
-
-
-NO_SPACY_TOKENIZER_ERROR_MSG = "You are trying to load a spacy tokenizer, for which you need `spacy` and its dependencies, which are not available in your environment. Please install them using `pip install lighteval[multilingual]`."
-
-def can_load_stanza_tokenizer() -> bool:
-    return importlib.util.find_spec("stanza") is not None
-
-
-NO_STANZA_TOKENIZER_ERROR_MSG = "You are trying to load a stanza tokenizer, for which you need `stanza`, which is not available in your environment. Please install it using `pip install lighteval[multilingual]`."
+    return all(cur_import is not None for cur_import in imports)
 
 
-# Better than having to check import every time
-def requires_latex2sympy2_extended(func):
-    checked_import = False
+def raise_if_package_not_available(package_name: str | Extras, *, language: str = None):
+    if package_name == Extras.MULTILINGUAL:
+        if not is_multilingual_package_available(language):
+            raise ImportError(not_installed_error_message(package_name))
 
-    def wrapper(*args, **kwargs):
-        nonlocal checked_import
-        if not checked_import and importlib.util.find_spec("latex2sympy2_extended") is None:
-            raise ImportError(NO_LATEX2SYMPY2_EXTENDED_ERROR_MSG)
-        checked_import = True
-        return func(*args, **kwargs)
+    elif not is_package_available(package_name):
+        raise ImportError(not_installed_error_message(package_name))
 
-    return wrapper
 
+def not_installed_error_message(package_name: str | Extras) -> str:
+    if package_name == Extras.MULTILINGUAL:
+        return "You are trying to run an evaluation requiring multilingual capabilities. Please install the required extra: `pip install lighteval[multilingual]`"
+    elif package_name == Extras.EXTENDED:
+        return "You are trying to run an evaluation requiring additional extensions. Please install the required extra: `pip install lighteval[extended]`"
+    elif package_name == "text_generation":
+        return "You are trying to start a text generation inference endpoint, but TGI is not present in your local environment. Please install it using pip."
+    elif package_name in ["bitsandbytes", "auto-gptq"]:
+        return f"You are trying to load a model quantized with `{package_name}`, which is not available in your local environment. Please install it using pip."
+    elif package_name == "peft":
+        return "You are trying to use adapter weights models, for which you need `peft`, which is not available in your environment. Please install it using pip."
+    elif package_name == "openai":
+        return "You are trying to use an OpenAI LLM as a judge, for which you need `openai`, which is not available in your environment. Please install it using pip."
 
-NO_LATEX2SYMPY2_EXTENDED_ERROR_MSG = "You are trying to parse latex expressions, for which you need `latex2sympy2_extended`, which is not available in your environment. Please install it using `pip install lighteval[math]`."
+    return f"You requested the use of `{package_name}` for this evaluation, but it is not available in your current environment. Please install it using pip."
 
 
-def is_spacy_available() -> bool:
-    return importlib.util.find_spec("spacy") is not None
+def requires(package_name):
+    def decorator(obj):
+        # `requires` is also applied to classes (e.g. VLLMModel, ModelClient): wrap
+        # __init__ instead of replacing the class, so subclassing and isinstance()
+        # checks keep working and the check fires on instantiation.
+        if isinstance(obj, type):
+            original_init = obj.__init__
+
+            def checked_init(self, *args, **kwargs):
+                raise_if_package_not_available(package_name)
+                original_init(self, *args, **kwargs)
+
+            obj.__init__ = checked_init
+            return obj
+
+        def wrapper(*args, **kwargs):
+            raise_if_package_not_available(package_name)
+            return obj(*args, **kwargs)
+
+        return wrapper
 
-NO_SPACY_ERROR_MSG = "You are trying to use some metrics requiring `spacy`, which is not available in your environment. Please install it using pip."
+    return decorator
diff --git a/src/lighteval/utils/parallelism.py b/src/lighteval/utils/parallelism.py
index 2e73f4c73..896183160 100644
--- a/src/lighteval/utils/parallelism.py
+++ b/src/lighteval/utils/parallelism.py
@@ -27,12 +27,7 @@
 
 import torch
 
-from lighteval.utils.imports import (
-    NO_ACCELERATE_ERROR_MSG,
-    NO_NANOTRON_ERROR_MSG,
-    is_accelerate_available,
-    is_nanotron_available,
-)
+from lighteval.utils.imports import raise_if_package_not_available
 
 logger = logging.getLogger(__name__)
@@ -131,16 +126,14 @@
         ImportError: If the required accelerator or parallel context is not available.
     """
     if accelerator:
-        if not is_accelerate_available():
-            raise ImportError(NO_ACCELERATE_ERROR_MSG)
+        raise_if_package_not_available("accelerate")
         logger.info("Test gather tensor")
         test_tensor: torch.Tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
         gathered_tensor: torch.Tensor = accelerator.gather(test_tensor)
         logger.info(f"gathered_tensor {gathered_tensor}, should be {list(range(accelerator.num_processes))}")
         accelerator.wait_for_everyone()
     elif parallel_context:
-        if not is_nanotron_available():
-            raise ImportError(NO_NANOTRON_ERROR_MSG)
+        raise_if_package_not_available("nanotron")
         from nanotron import distributed as dist
         from nanotron import logging
diff --git a/tests/pipeline/test_reasoning_tags.py b/tests/pipeline/test_reasoning_tags.py
index 84dfb9e7e..f772970c4 100644
--- a/tests/pipeline/test_reasoning_tags.py
+++ b/tests/pipeline/test_reasoning_tags.py
@@ -35,7 +35,7 @@
 from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig
 from lighteval.tasks.registry import Registry
 from lighteval.tasks.requests import Doc, SamplingMethod
-from lighteval.utils.imports import is_accelerate_available
+from lighteval.utils.imports import is_package_available
 
 
 class TestPipelineReasoningTags(unittest.TestCase):
@@ -129,7 +129,7 @@
         )
 
         # Initialize accelerator if available
-        if is_accelerate_available():
+        if is_package_available("accelerate"):
             from accelerate import Accelerator
 
             Accelerator()
@@ -175,7 +175,7 @@
         )
 
         # Initialize accelerator if available
-        if is_accelerate_available():
+        if is_package_available("accelerate"):
             from accelerate import Accelerator
 
             Accelerator()
@@ -221,7 +221,7 @@
         )
 
         # Initialize accelerator if available
-        if is_accelerate_available():
+        if is_package_available("accelerate"):
             from accelerate import Accelerator
 
             Accelerator()
@@ -264,7 +264,7 @@
         )
 
         # Initialize accelerator if available
-        if is_accelerate_available():
+        if is_package_available("accelerate"):
             from accelerate import Accelerator
 
             Accelerator()
@@ -310,7 +310,7 @@
         )
 
         # Initialize accelerator if available
-        if is_accelerate_available():
+        if is_package_available("accelerate"):
             from accelerate import Accelerator
 
             Accelerator()
@@ -356,7 +356,7 @@
         )
 
         # Initialize accelerator if available
-        if is_accelerate_available():
+        if is_package_available("accelerate"):
             from accelerate import Accelerator
 
             Accelerator()
diff --git a/tests/tasks/test_registry.py b/tests/tasks/test_registry.py
index 106708549..775c9bbda 100644
--- a/tests/tasks/test_registry.py
+++ b/tests/tasks/test_registry.py
@@ -167,3 +167,32 @@
 
     assert isinstance(task, LightevalTask)
     assert task.name == "storycloze:2016"
+
+
+def test_fewshot_can_be_inferred():
+    """
+    Tests that the few-shot count can be inferred without being explicitly specified.
+    """
+    registry = Registry(tasks="lighteval|storycloze:2016")
+    tasks = registry.load_tasks()
+
+    assert "lighteval|storycloze:2016|0" in tasks
+    assert registry.task_to_configs["lighteval|storycloze:2016"][0].num_fewshots == 0
+
+
+def test_suite_can_be_inferred():
+    """
+    Tests that the suite can be inferred from a bare task name.
+    """
+    registry = Registry(tasks="storycloze:2016")
+    tasks = registry.load_tasks()
+    assert "lighteval|storycloze:2016|0" in tasks
+
+
+def test_multilingual_suite_can_be_inferred():
+    """
+    Tests that the suite of a multilingual task is inferred when multilingual loading is enabled.
+    """
+    registry = Registry(tasks="indicnxnli_tam_hybrid", load_multilingual=True)
+    tasks = registry.load_tasks()
+    assert "lighteval|indicnxnli_tam_hybrid|0" in tasks
diff --git a/tests/utils.py b/tests/utils.py
index b44d27551..138d282b3 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -34,7 +34,7 @@
 from lighteval.tasks.lighteval_task import LightevalTask
 from lighteval.tasks.registry import Registry
 from lighteval.tasks.requests import Doc
-from lighteval.utils.imports import is_accelerate_available
+from lighteval.utils.imports import is_package_available
 
 
 class FakeModelConfig(ModelConfig):
@@ -115,7 +115,7 @@
 
     # This is due to logger complaining we have no initialised the accelerator
     # It's hard to mock as it's global singleton
-    if is_accelerate_available():
+    if is_package_available("accelerate"):
         from accelerate import Accelerator
 
         Accelerator()
diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py
index 1d8f6060d..0add6edf2 100644
--- a/tests/utils/test_caching.py
+++ b/tests/utils/test_caching.py
@@ -224,9 +224,9 @@ def test_cache_vllm(self, mock_create_model, mock_greedy_until, mock_loglikeliho
     @patch("lighteval.models.endpoints.tgi_model.ModelClient._loglikelihood")
     def test_cache_tgi(self, mock_loglikelihood, mock_greedy_until, mock_requests_get):
         from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig
-        from lighteval.utils.imports import is_tgi_available
+        from lighteval.utils.imports import is_package_available
 
-        if not is_tgi_available():
+        if not is_package_available("tgi"):
             pytest.skip("Skipping because missing the imports")
 
         # Mock TGI requests
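Taken together, the patch collapses the per-package `is_x_available()` / `NO_X_ERROR_MSG` pairs into one parameterized surface. A short sketch of the three usage patterns the diff converges on (hypothetical call sites; `vllm` stands in for any optional backend):

```python
from lighteval.utils.imports import (
    is_package_available,
    raise_if_package_not_available,
    requires,
)

# 1. Guarded module-level import: only import the backend when it is present.
if is_package_available("vllm"):
    from vllm import LLM

# 2. Eager check at a call site: raises ImportError with an install hint.
def load_model(config):
    raise_if_package_not_available("vllm")
    ...

# 3. Deferred check: the decorated function or class only errors on first use.
@requires("vllm")
def generate(prompts):
    ...
```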