|
13 | 13 | from bioimageio.core import export_resource_package, load_raw_resource_description |
14 | 14 | from bioimageio.core.resource_io.nodes import URI |
15 | 15 | from bioimageio.core.resource_io.utils import resolve_local_source, resolve_source |
| 16 | +from bioimageio.spec.shared.raw_nodes import ImportableSourceFile, ImportableModule |
16 | 17 |
|
17 | 18 | try: |
18 | 19 | from typing import get_args |
@@ -56,33 +57,23 @@ def _infer_weight_type(path): |
56 | 57 | def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root): |
57 | 58 | assert architecture is not None |
58 | 59 | tmp_archtecture = None |
59 | | - |
60 | | - # if we have a ":" (or deprecated "::") this is a python file with class specified, |
61 | | - # so we can compute the hash for it |
62 | | - if ":" in architecture: |
63 | | - arch_file, arch_class = architecture.replace("::", ":").split(":") |
64 | | - |
65 | | - # get the source path |
66 | | - arch_file = _ensure_local(arch_file, root) |
67 | | - arch_hash = _get_hash(root / arch_file) |
68 | | - |
69 | | - # if not relative, create local copy (otherwise this will not work) |
70 | | - if os.path.isabs(arch_file): |
71 | | - copyfile(arch_file, "this_model_architecture.py") |
72 | | - arch = f"this_model_architecture.py:{arch_class}" |
73 | | - tmp_archtecture = "this_model_architecture.py" |
74 | | - else: |
75 | | - arch = f"{arch_file}:{arch_class}" |
76 | | - arch = spec.shared.fields.Importablearch().deserialize(arch) |
77 | | - |
78 | | - weight_kwargs = {"architecture": arch, "architecture_sha256": arch_hash} |
79 | | - |
80 | | - # otherwise this is a python class or function name |
| 60 | + weight_kwargs = {"kwargs": model_kwargs} if model_kwargs else {} |
| 61 | + arch = spec.shared.fields.ImportableSource().deserialize(architecture) |
| 62 | + if isinstance(arch, ImportableSourceFile): |
| 63 | + if os.path.isabs(arch.source_file): |
| 64 | + tmp_archtecture = Path("this_model_architecture.py") |
| 65 | + |
| 66 | + copyfile(arch.source_file, root / tmp_archtecture) |
| 67 | + arch = ImportableSourceFile(arch.callable_name, tmp_archtecture) |
| 68 | + |
| 69 | + arch_hash = _get_hash(root / arch.source_file) |
| 70 | + weight_kwargs["architecture_sha256"] = arch_hash |
| 71 | + elif isinstance(arch, ImportableModule): |
| 72 | + pass |
81 | 73 | else: |
82 | | - weight_kwargs = {"architecture": architecture} |
| 74 | + raise NotImplementedError(arch) |
83 | 75 |
|
84 | | - if model_kwargs is not None: |
85 | | - weight_kwargs["kwargs"] = model_kwargs |
| 76 | + weight_kwargs["architecture"] = arch |
86 | 77 |
|
87 | 78 | return weight_kwargs, tmp_archtecture |
88 | 79 |
|
|
0 commit comments