Skip to content

Commit bd04ddb

Browse files
committed
v2
1 parent 2a7b29a commit bd04ddb

File tree

8 files changed

+108
-43
lines changed

8 files changed

+108
-43
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ dependencies = [
8484
"fsspec>=2023.12.2",
8585
"httpx>=0.27.2",
8686
"latex2sympy2_extended==1.0.6",
87+
"langcodes"
8788
]
8889

8990
[project.optional-dependencies]

src/lighteval/metrics/normalizations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,9 +450,9 @@ def get_multilingual_normalizer(lang: Language, lower: bool = True) -> Callable[
450450
Returns:
451451
Callable[[str], str]: A function that normalizes text for the specified language
452452
"""
453-
tokenizer = get_word_tokenizer(lang)
454453

455454
def _inner_normalizer(text: str) -> str:
455+
tokenizer = get_word_tokenizer(lang)
456456
text = remove_articles(text, lang)
457457
text = remove_punc(text)
458458
if lower:

src/lighteval/models/model_loader.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig
4444
from lighteval.models.transformers.vlm_transformers_model import VLMTransformersModel, VLMTransformersModelConfig
4545
from lighteval.models.vllm.vllm_model import AsyncVLLMModel, VLLMModel, VLLMModelConfig
46-
from lighteval.utils.imports import raise_if_package_not_available, requires
46+
from lighteval.utils.imports import raise_if_package_not_available
4747

4848

4949
logger = logging.getLogger(__name__)
@@ -92,14 +92,12 @@ def load_model( # noqa: C901
9292
return load_inference_providers_model(config=config)
9393

9494

95-
@requires("tgi")
9695
def load_model_with_tgi(config: TGIModelConfig):
9796
logger.info(f"Load model from inference server: {config.inference_server_address}")
9897
model = ModelClient(config=config)
9998
return model
10099

101100

102-
@requires("litellm")
103101
def load_litellm_model(config: LiteLLMModelConfig):
104102
model = LiteLLMClient(config)
105103
return model
@@ -171,6 +169,5 @@ def load_inference_providers_model(config: InferenceProvidersModelConfig):
171169
return InferenceProvidersClient(config=config)
172170

173171

174-
@requires("sglang")
175172
def load_sglang_model(config: SGLangModelConfig):
176173
return SGLangModel(config=config)

src/lighteval/tasks/extended/ifeval/instructions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import re
2222
import string
2323

24-
from ....utils.imports import is_package_available
24+
from lighteval.utils.imports import is_package_available
2525

2626

2727
if is_package_available("langdetect"):

src/lighteval/tasks/extended/ifeval/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict:
125125
}
126126

127127

128+
@requires("langdetect")
128129
def agg_inst_level_acc(items):
129130
flat_items = [item for sublist in items for item in sublist]
130131
inst_level_acc = sum(flat_items) / len(flat_items)

src/lighteval/tasks/registry.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,6 @@
3636
import lighteval.tasks.default_tasks as default_tasks
3737
from lighteval.tasks.extended import AVAILABLE_EXTENDED_TASKS_MODULES
3838
from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig
39-
from lighteval.utils.imports import (
40-
CANNOT_USE_EXTENDED_TASKS_MSG,
41-
CANNOT_USE_MULTILINGUAL_TASKS_MSG,
42-
can_load_extended_tasks,
43-
can_load_multilingual_tasks,
44-
)
4539

4640

4741
# Import community tasks
@@ -121,9 +115,9 @@ def __init__(
121115
self,
122116
tasks: str | Path | None = None,
123117
custom_tasks: str | Path | ModuleType | None = None,
124-
load_community: bool = False,
125-
load_extended: bool = False,
126-
load_multilingual: bool = False,
118+
load_community: bool = True,
119+
load_extended: bool = True,
120+
load_multilingual: bool = True,
127121
):
128122
"""
129123
Initialize the Registry class.
@@ -219,17 +213,6 @@ def _activate_loading_of_optional_suite(self) -> None:
219213
f"Suite {suite_name} unknown. This is not normal, unless you are testing adding new evaluations."
220214
)
221215

222-
if "extended" in suites:
223-
if not can_load_extended_tasks():
224-
raise ImportError(CANNOT_USE_EXTENDED_TASKS_MSG)
225-
self._load_extended = True
226-
if "multilingual" in suites:
227-
if not can_load_multilingual_tasks():
228-
raise ImportError(CANNOT_USE_MULTILINGUAL_TASKS_MSG)
229-
self._load_multilingual = True
230-
if "community" in suites:
231-
self._load_community = True
232-
233216
def _load_full_registry(self) -> dict[str, LightevalTaskConfig]:
234217
"""
235218
Returns:
@@ -251,8 +234,6 @@ def _load_full_registry(self) -> dict[str, LightevalTaskConfig]:
251234
if self._load_extended:
252235
for extended_task_module in AVAILABLE_EXTENDED_TASKS_MODULES:
253236
custom_tasks_module.append(extended_task_module)
254-
else:
255-
logger.warning(CANNOT_USE_EXTENDED_TASKS_MSG)
256237

257238
# Need to load community tasks
258239
if self._load_community:

src/lighteval/utils/imports.py

Lines changed: 100 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import enum
15+
import functools
1516
import importlib
17+
import inspect
18+
import operator
19+
import re
20+
from enum import Enum
1621
from functools import lru_cache
22+
from typing import Callable
23+
24+
from packaging.requirements import Requirement
25+
from packaging.version import Version
1726

1827

1928
class Extras(enum.Enum):
@@ -22,9 +31,9 @@ class Extras(enum.Enum):
2231

2332

2433
@lru_cache()
25-
def is_package_available(package_name: str):
34+
def is_package_available(package_name: str | Extras):
2635
if package_name == Extras.MULTILINGUAL:
27-
return all(importlib.util.find_spec(package) is not None for package in ["stanza", "spacy", "langcodes"])
36+
return all(importlib.util.find_spec(package) is not None for package in ["stanza", "spacy"])
2837
if package_name == Extras.EXTENDED:
2938
return all(importlib.util.find_spec(package) is not None for package in ["spacy"])
3039
else:
@@ -46,12 +55,14 @@ def is_multilingual_package_available(language: str):
4655
return all(cur_import is not None for cur_import in imports)
4756

4857

49-
def raise_if_package_not_available(package_name: str | Extras, *, language: str = None):
58+
def raise_if_package_not_available(package_name: str | Extras, *, language: str = None, object_name: str = None):
59+
prefix = "You" if object_name is None else f"Through the use of {object_name}, you"
60+
5061
if package_name == Extras.MULTILINGUAL and not is_multilingual_package_available(language):
51-
raise ImportError(not_installed_error_message(package_name))
62+
raise ImportError(prefix + not_installed_error_message(package_name)[3:])
5263

5364
if not is_package_available(package_name):
54-
raise ImportError(not_installed_error_message(package_name))
65+
raise ImportError(prefix + not_installed_error_message(package_name)[3:])
5566

5667

5768
def not_installed_error_message(package_name: str | Extras) -> str:
@@ -71,12 +82,89 @@ def not_installed_error_message(package_name: str | Extras) -> str:
7182
return f"You requested the use of `{package_name}` for this evaluation, but it is not available in your current environement. Please install it using pip."
7283

7384

74-
def requires(package_name):
75-
def decorator(func):
76-
def wrapper(*args, **kwargs):
77-
raise_if_package_not_available(package_name)
78-
return func(*args, **kwargs)
85+
class DummyObject(type):
86+
"""
87+
Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
88+
`requires_backend` each time a user tries to access any method of that class.
89+
"""
90+
91+
is_dummy = True
92+
93+
def __getattribute__(cls, key):
94+
if (key.startswith("_") and key != "_from_config") or key == "is_dummy" or key == "mro" or key == "call":
95+
return super().__getattribute__(key)
96+
97+
for backend in cls._backends:
98+
raise_if_package_not_available(backend)
99+
100+
101+
class VersionComparison(Enum):
102+
EQUAL = operator.eq
103+
NOT_EQUAL = operator.ne
104+
GREATER_THAN = operator.gt
105+
LESS_THAN = operator.lt
106+
GREATER_THAN_OR_EQUAL = operator.ge
107+
LESS_THAN_OR_EQUAL = operator.le
108+
109+
@staticmethod
110+
def from_string(version_string: str) -> Callable[[int | Version, int | Version], bool]:
111+
string_to_operator = {
112+
"=": VersionComparison.EQUAL.value,
113+
"==": VersionComparison.EQUAL.value,
114+
"!=": VersionComparison.NOT_EQUAL.value,
115+
">": VersionComparison.GREATER_THAN.value,
116+
"<": VersionComparison.LESS_THAN.value,
117+
">=": VersionComparison.GREATER_THAN_OR_EQUAL.value,
118+
"<=": VersionComparison.LESS_THAN_OR_EQUAL.value,
119+
}
120+
121+
return string_to_operator[version_string]
122+
123+
124+
@lru_cache
125+
def split_package_version(package_version_str) -> tuple[str, str, str]:
126+
pattern = r"([a-zA-Z0-9_-]+)([!<>=~]+)([0-9.]+)"
127+
match = re.match(pattern, package_version_str)
128+
if match:
129+
return (match.group(1), match.group(2), match.group(3))
130+
else:
131+
raise ValueError(f"Invalid package version string: {package_version_str}")
132+
133+
134+
def requires(*backends):
135+
"""
136+
Decorator to raise an ImportError if the decorated object (function or class) requires a dependency
137+
which is not installed.
138+
"""
139+
140+
applied_backends = []
141+
for backend in backends:
142+
applied_backends.append(Requirement(backend.value if isinstance(backend, Extras) else backend))
143+
144+
def inner_fn(_object):
145+
_object._backends = applied_backends
146+
147+
if inspect.isclass(_object):
148+
149+
class Placeholder(metaclass=DummyObject):
150+
_backends = applied_backends
151+
152+
def __init__(self, *args, **kwargs):
153+
for backend in self._backends:
154+
raise_if_package_not_available(backend.name, object_name=_object.__class__.__name__)
155+
156+
Placeholder.__name__ = _object.__name__
157+
Placeholder.__module__ = _object.__module__
158+
159+
return Placeholder
160+
else:
161+
162+
@functools.wraps(_object)
163+
def wrapper(*args, **kwargs):
164+
for backend in _object._backends:
165+
raise_if_package_not_available(backend.name, object_name=_object.__name__)
166+
return _object(*args, **kwargs)
79167

80-
return wrapper
168+
return wrapper
81169

82-
return decorator
170+
return inner_fn

src/lighteval/utils/parallelism.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,6 @@ def test_all_gather(accelerator=None, parallel_context=None):
121121
Args:
122122
accelerator (Optional): The accelerator object used for parallelism.
123123
parallel_context (Optional): The parallel context object used for parallelism.
124-
125-
Raises:
126-
ImportError: If the required accelerator or parallel context is not available.
127124
"""
128125
if accelerator:
129126
raise_if_package_not_available("accelerate")

0 commit comments

Comments
 (0)