Skip to content

Commit e90c621

Browse files
committed
update architecture handling in build_model
1 parent 6d7be5c commit e90c621

File tree

2 files changed

+21
-32
lines changed

2 files changed

+21
-32
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from bioimageio.core import export_resource_package, load_raw_resource_description
1414
from bioimageio.core.resource_io.nodes import URI
1515
from bioimageio.core.resource_io.utils import resolve_local_source, resolve_source
16+
from bioimageio.spec.shared.raw_nodes import ImportableSourceFile, ImportableModule
1617

1718
try:
1819
from typing import get_args
@@ -56,33 +57,23 @@ def _infer_weight_type(path):
5657
def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
5758
assert architecture is not None
5859
tmp_archtecture = None
59-
60-
# if we have a ":" (or deprecated "::") this is a python file with class specified,
61-
# so we can compute the hash for it
62-
if ":" in architecture:
63-
arch_file, arch_class = architecture.replace("::", ":").split(":")
64-
65-
# get the source path
66-
arch_file = _ensure_local(arch_file, root)
67-
arch_hash = _get_hash(root / arch_file)
68-
69-
# if not relative, create local copy (otherwise this will not work)
70-
if os.path.isabs(arch_file):
71-
copyfile(arch_file, "this_model_architecture.py")
72-
arch = f"this_model_architecture.py:{arch_class}"
73-
tmp_archtecture = "this_model_architecture.py"
74-
else:
75-
arch = f"{arch_file}:{arch_class}"
76-
arch = spec.shared.fields.Importablearch().deserialize(arch)
77-
78-
weight_kwargs = {"architecture": arch, "architecture_sha256": arch_hash}
79-
80-
# otherwise this is a python class or function name
60+
weight_kwargs = {"kwargs": model_kwargs} if model_kwargs else {}
61+
arch = spec.shared.fields.ImportableSource().deserialize(architecture)
62+
if isinstance(arch, ImportableSourceFile):
63+
if os.path.isabs(arch.source_file):
64+
tmp_archtecture = Path("this_model_architecture.py")
65+
66+
copyfile(arch.source_file, root / tmp_archtecture)
67+
arch = ImportableSourceFile(arch.callable_name, tmp_archtecture)
68+
69+
arch_hash = _get_hash(root / arch.source_file)
70+
weight_kwargs["architecture_sha256"] = arch_hash
71+
elif isinstance(arch, ImportableModule):
72+
pass
8173
else:
82-
weight_kwargs = {"architecture": architecture}
74+
raise NotImplementedError(arch)
8375

84-
if model_kwargs is not None:
85-
weight_kwargs["kwargs"] = model_kwargs
76+
weight_kwargs["architecture"] = arch
8677

8778
return weight_kwargs, tmp_archtecture
8879

tests/build_spec/test_build_spec.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,15 @@ def _test_build_spec(
2323

2424
if weight_type == "pytorch_state_dict":
2525
weight_spec = model_spec.weights["pytorch_state_dict"]
26-
source_path = weight_spec.architecture.source_file
27-
class_name = weight_spec.architecture.callable_name
2826
model_kwargs = None if weight_spec.kwargs is missing else weight_spec.kwargs
29-
model_source = f"{source_path}:{class_name}"
27+
architecture = str(weight_spec.architecture)
3028
weight_type_ = None # the weight type can be auto-detected
3129
elif weight_type == "pytorch_script":
32-
model_source = None
30+
architecture = None
3331
model_kwargs = None
3432
weight_type_ = "pytorch_script" # the weight type CANNOT be auto-detcted
3533
else:
36-
model_source = None
34+
architecture = None
3735
model_kwargs = None
3836
weight_type_ = None # the weight type can be auto-detected
3937

@@ -69,8 +67,8 @@ def _test_build_spec(
6967
add_deepimagej_config=add_deepimagej_config,
7068
)
7169
# TODO names
72-
if model_source is not None:
73-
kwargs["source"] = model_source
70+
if architecture is not None:
71+
kwargs["architecture"] = architecture
7472
if model_kwargs is not None:
7573
kwargs["kwargs"] = model_kwargs
7674
if tensorflow_version is not None:

0 commit comments

Comments
 (0)