1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import enum
15+ import functools
1516import importlib
17+ import inspect
18+ import operator
19+ import re
20+ from enum import Enum
1621from functools import lru_cache
22+ from typing import Callable
23+
24+ from packaging .requirements import Requirement
25+ from packaging .version import Version
1726
1827
1928class Extras (enum .Enum ):
@@ -22,9 +31,9 @@ class Extras(enum.Enum):
2231
2332
2433@lru_cache ()
25- def is_package_available (package_name : str ):
34+ def is_package_available (package_name : str | Extras ):
2635 if package_name == Extras .MULTILINGUAL :
27- return all (importlib .util .find_spec (package ) is not None for package in ["stanza" , "spacy" , "langcodes" ])
36+ return all (importlib .util .find_spec (package ) is not None for package in ["stanza" , "spacy" ])
2837 if package_name == Extras .EXTENDED :
2938 return all (importlib .util .find_spec (package ) is not None for package in ["spacy" ])
3039 else :
@@ -46,12 +55,14 @@ def is_multilingual_package_available(language: str):
4655 return all (cur_import is not None for cur_import in imports )
4756
4857
49- def raise_if_package_not_available (package_name : str | Extras , * , language : str = None ):
58+ def raise_if_package_not_available (package_name : str | Extras , * , language : str = None , object_name : str = None ):
59+ prefix = "You" if object_name is None else f"Through the use of { object_name } , you"
60+
5061 if package_name == Extras .MULTILINGUAL and not is_multilingual_package_available (language ):
51- raise ImportError (not_installed_error_message (package_name ))
62+ raise ImportError (prefix + not_installed_error_message (package_name )[ 3 :] )
5263
5364 if not is_package_available (package_name ):
54- raise ImportError (not_installed_error_message (package_name ))
65+ raise ImportError (prefix + not_installed_error_message (package_name )[ 3 :] )
5566
5667
5768def not_installed_error_message (package_name : str | Extras ) -> str :
@@ -71,12 +82,89 @@ def not_installed_error_message(package_name: str | Extras) -> str:
7182 return f"You requested the use of `{ package_name } ` for this evaluation, but it is not available in your current environement. Please install it using pip."
7283
7384
74- def requires (package_name ):
75- def decorator (func ):
76- def wrapper (* args , ** kwargs ):
77- raise_if_package_not_available (package_name )
78- return func (* args , ** kwargs )
85+ class DummyObject (type ):
86+ """
87+ Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
88+ `requires_backend` each time a user tries to access any method of that class.
89+ """
90+
91+ is_dummy = True
92+
93+ def __getattribute__ (cls , key ):
94+ if (key .startswith ("_" ) and key != "_from_config" ) or key == "is_dummy" or key == "mro" or key == "call" :
95+ return super ().__getattribute__ (key )
96+
97+ for backend in cls ._backends :
98+ raise_if_package_not_available (backend )
99+
100+
101+ class VersionComparison (Enum ):
102+ EQUAL = operator .eq
103+ NOT_EQUAL = operator .ne
104+ GREATER_THAN = operator .gt
105+ LESS_THAN = operator .lt
106+ GREATER_THAN_OR_EQUAL = operator .ge
107+ LESS_THAN_OR_EQUAL = operator .le
108+
109+ @staticmethod
110+ def from_string (version_string : str ) -> Callable [[int | Version , int | Version ], bool ]:
111+ string_to_operator = {
112+ "=" : VersionComparison .EQUAL .value ,
113+ "==" : VersionComparison .EQUAL .value ,
114+ "!=" : VersionComparison .NOT_EQUAL .value ,
115+ ">" : VersionComparison .GREATER_THAN .value ,
116+ "<" : VersionComparison .LESS_THAN .value ,
117+ ">=" : VersionComparison .GREATER_THAN_OR_EQUAL .value ,
118+ "<=" : VersionComparison .LESS_THAN_OR_EQUAL .value ,
119+ }
120+
121+ return string_to_operator [version_string ]
122+
123+
124+ @lru_cache
125+ def split_package_version (package_version_str ) -> tuple [str , str , str ]:
126+ pattern = r"([a-zA-Z0-9_-]+)([!<>=~]+)([0-9.]+)"
127+ match = re .match (pattern , package_version_str )
128+ if match :
129+ return (match .group (1 ), match .group (2 ), match .group (3 ))
130+ else :
131+ raise ValueError (f"Invalid package version string: { package_version_str } " )
132+
133+
134+ def requires (* backends ):
135+ """
136+ Decorator to raise an ImportError if the decorated object (function or class) requires a dependency
137+ which is not installed.
138+ """
139+
140+ applied_backends = []
141+ for backend in backends :
142+ applied_backends .append (Requirement (backend .value if isinstance (backend , Extras ) else backend ))
143+
144+ def inner_fn (_object ):
145+ _object ._backends = applied_backends
146+
147+ if inspect .isclass (_object ):
148+
149+ class Placeholder (metaclass = DummyObject ):
150+ _backends = applied_backends
151+
152+ def __init__ (self , * args , ** kwargs ):
153+ for backend in self ._backends :
154+ raise_if_package_not_available (backend .name , object_name = _object .__class__ .__name__ )
155+
156+ Placeholder .__name__ = _object .__name__
157+ Placeholder .__module__ = _object .__module__
158+
159+ return Placeholder
160+ else :
161+
162+ @functools .wraps (_object )
163+ def wrapper (* args , ** kwargs ):
164+ for backend in _object ._backends :
165+ raise_if_package_not_available (backend .name , object_name = _object .__name__ )
166+ return _object (* args , ** kwargs )
79167
80- return wrapper
168+ return wrapper
81169
82- return decorator
170+ return inner_fn
0 commit comments