|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import importlib.util |
4 | | -from functools import singledispatch |
5 | 4 | from itertools import chain |
6 | 5 | from typing import ( |
7 | 6 | Any, |
|
20 | 19 | from numpy.typing import NDArray |
21 | 20 | from typing_extensions import Unpack, assert_never |
22 | 21 |
|
23 | | -from bioimageio.spec._internal.io import HashKwargs, download |
| 22 | +from bioimageio.spec._internal.io_utils import HashKwargs, download |
24 | 23 | from bioimageio.spec.common import FileSource |
25 | 24 | from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 |
26 | 25 | from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile |
|
44 | 43 | from .tensor import Tensor |
45 | 44 |
|
46 | 45 |
|
47 | | -@singledispatch |
48 | | -def import_callable(node: type, /) -> Callable[..., Any]: |
| 46 | +def import_callable( |
| 47 | + node: Union[CallableFromDepencency, ArchitectureFromLibraryDescr], |
| 48 | + /, |
| 49 | + **kwargs: Unpack[HashKwargs], |
| 50 | +) -> Callable[..., Any]: |
49 | 51 | """import a callable (e.g. a torch.nn.Module) from a spec node describing it""" |
50 | | - raise TypeError(type(node)) |
51 | | - |
52 | | - |
53 | | -@import_callable.register |
54 | | -def _(node: CallableFromDepencency, **kwargs: Unpack[HashKwargs]) -> Callable[..., Any]: |
55 | | - module = importlib.import_module(node.module_name) |
56 | | - c = getattr(module, str(node.callable_name)) |
57 | | - if not callable(c): |
58 | | - raise ValueError(f"{node} (imported: {c}) is not callable") |
59 | | - |
60 | | - return c |
| 52 | + if isinstance(node, CallableFromDepencency): |
| 53 | + module = importlib.import_module(node.module_name) |
| 54 | + c = getattr(module, str(node.callable_name)) |
| 55 | + elif isinstance(node, ArchitectureFromLibraryDescr): |
| 56 | + module = importlib.import_module(node.import_from) |
| 57 | + c = getattr(module, str(node.callable)) |
| 58 | + elif isinstance(node, CallableFromFile): |
| 59 | + c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs) |
| 60 | + elif isinstance(node, ArchitectureFromFileDescr): |
| 61 | + c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256) |
61 | 62 |
|
| 63 | + else: |
| 64 | + assert_never(node) |
62 | 65 |
|
63 | | -@import_callable.register |
64 | | -def _( |
65 | | - node: ArchitectureFromLibraryDescr, **kwargs: Unpack[HashKwargs] |
66 | | -) -> Callable[..., Any]: |
67 | | - module = importlib.import_module(node.import_from) |
68 | | - c = getattr(module, str(node.callable)) |
69 | 66 | if not callable(c): |
70 | 67 | raise ValueError(f"{node} (imported: {c}) is not callable") |
71 | 68 |
|
72 | 69 | return c |
73 | 70 |
|
74 | 71 |
|
75 | | -@import_callable.register |
76 | | -def _(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): |
77 | | - return _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs) |
78 | | - |
79 | | - |
80 | | -@import_callable.register |
81 | | -def _(node: ArchitectureFromFileDescr, **kwargs: Unpack[HashKwargs]): |
82 | | - return _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256) |
83 | | - |
84 | | - |
85 | 72 | def _import_from_file_impl( |
86 | 73 | source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs] |
87 | 74 | ): |
|
0 commit comments