Skip to content

Commit 47c271f

Browse files
authored
Add condition to ModuleType (#223)
* Add and test condition * Fix RTD config
1 parent d51c377 commit 47c271f

File tree

4 files changed

+61
-5
lines changed

4 files changed

+61
-5
lines changed

.readthedocs.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ python:
1717
sphinx:
1818
builder: html
1919
fail_on_warning: true
20+
configuration: docs/conf.py

docs/types.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,32 @@ plum.type.ModuleType[tensorflow.python.framework.ops.EagerTensor]
226226
tensorflow.python.framework.ops.EagerTensor
227227
```
228228

229+
You might run into a scenario where an import is only possible when a certain condition
230+
is satisfied, e.g. a constraint on the package version.
231+
You can specify a condition with the keyword argument `condition`.
232+
233+
Example:
234+
235+
```python
236+
>>> def jax_version():
237+
... import sys
238+
... version_string = sys.modules["jax.version"].__version__
239+
... return tuple(int(x) for x in version_string.split("."))
240+
241+
>>> ArrayImpl = Union[
242+
... ModuleType(
243+
... "jaxlib.xla_extension",
244+
... "ArrayImpl",
245+
... condition=lambda: jax_version() < (0, 6, 0),
246+
... ),
247+
... ModuleType(
248+
... "jaxlib._jax",
249+
... "ArrayImpl",
250+
... condition=lambda: jax_version() >= (0, 6, 0),
251+
... ),
252+
... ]
253+
```
254+
229255
% skip: end
230256

231257
(promisedtype)=

plum/type.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import sys
33
import typing
44
import warnings
5-
from collections.abc import Hashable
6-
from typing import Union, get_args, get_origin
5+
from collections.abc import Callable, Hashable
6+
from typing import Optional, Union, get_args, get_origin
77

88
from typing_extensions import Self, TypeGuard
99

@@ -91,17 +91,34 @@ class ModuleType(ResolvableType):
9191
name (str): Name of the type that is promised.
9292
allow_fail (bool, optional): If the type is does not exist in `module`,
9393
do not raise an `AttributeError`.
94+
condition (Callable[[], bool], optional): A callable that can check a condition,
95+
like a package version. This callable will be run whenever `module` has been
96+
imported. Only if the callable returns `True`, `name` will be imported
97+
from `module`.
9498
"""
9599

96-
def __init__(self, module: str, name: str, allow_fail: bool = False) -> None:
100+
def __init__(
101+
self,
102+
module: str,
103+
name: str,
104+
allow_fail: bool = False,
105+
condition: Optional[Callable[[], bool]] = None,
106+
) -> None:
97107
if module in {"__builtin__", "__builtins__"}:
98108
module = "builtins"
99109
ResolvableType.__init__(self, f"ModuleType[{module}.{name}]")
100110
self._name = name
101111
self._module = module
102112
self._allow_fail = allow_fail
103-
104-
def __new__(cls, module: str, name: str, allow_fail: bool = False) -> Self:
113+
self._condition = condition
114+
115+
def __new__(
116+
cls,
117+
module: str,
118+
name: str,
119+
allow_fail: bool = False,
120+
condition: Optional[Callable[[], bool]] = None,
121+
) -> Self:
105122
return ResolvableType.__new__(cls, f"ModuleType[{module}.{name}]")
106123

107124
def retrieve(self) -> bool:
@@ -111,6 +128,10 @@ def retrieve(self) -> bool:
111128
bool: Whether the retrieval succeeded.
112129
"""
113130
if self._type is None and self._module in sys.modules:
131+
# If a condition is given, check the condition before attempting to import.
132+
if self._condition is not None and not self._condition():
133+
return False
134+
114135
type = sys.modules[self._module]
115136
for name in self._name.split("."):
116137
# If `type` does not contain `name` and `self._allow_fail` is

tests/test_type.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,14 @@ def test_moduletype_allow_fail():
8686
assert not t_allowed.retrieve()
8787

8888

89+
def test_moduletype_condition():
90+
store = {"condition": False}
91+
t = ModuleType("builtins", "int", condition=lambda: store["condition"])
92+
assert not t.retrieve()
93+
store["condition"] = True
94+
assert t.retrieve()
95+
96+
8997
def test_is_hint():
9098
assert not _is_hint(int)
9199
assert _is_hint(typing.Union[int, float])

0 commit comments

Comments
 (0)