
Commit 9fc098f

Fix issues in build_spec WIP
1 parent da34a74 commit 9fc098f

2 files changed: +55, -56 lines


bioimageio/core/build_spec/build_model.py

Lines changed: 44 additions & 52 deletions
@@ -53,63 +53,68 @@ def _infer_weight_type(path):
     raise ValueError(f"Could not infer weight type from extension {ext} for weight file {path}")
 
 
-def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
-    weight_path = resolve_source(original_weight_source, root)
-    if weight_type is None:
-        weight_type = _infer_weight_type(weight_path)
-    weight_hash = _get_hash(weight_path)
+def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
+    assert architecture is not None
+    tmp_archtecture = None
 
-    tmp_source = None
     # if we have a ":" (or deprecated "::") this is a python file with class specified,
     # so we can compute the hash for it
-    if source is not None and ":" in source:
-        source_file, source_class = source.replace("::", ":").split(":")
+    if ":" in architecture:
+        arch_file, arch_class = architecture.replace("::", ":").split(":")
 
         # get the source path
-        source_file = _ensure_local(source_file, root)
-        source_hash = _get_hash(root / source_file)
+        arch_file = _ensure_local(arch_file, root)
+        arch_hash = _get_hash(root / arch_file)
 
         # if not relative, create local copy (otherwise this will not work)
-        if os.path.isabs(source_file):
-            copyfile(source_file, "this_model_architecture.py")
-            source = f"this_model_architecture.py:{source_class}"
-            tmp_source = "this_model_architecture.py"
+        if os.path.isabs(arch_file):
+            copyfile(arch_file, "this_model_architecture.py")
+            arch = f"this_model_architecture.py:{arch_class}"
+            tmp_archtecture = "this_model_architecture.py"
         else:
-            source = f"{source_file}:{source_class}"
-        source = spec.shared.fields.ImportableSource().deserialize(source)
+            arch = f"{arch_file}:{arch_class}"
+        arch = spec.shared.fields.Importablearch().deserialize(arch)
+
+        weight_kwargs = {"architecture": arch, "architecture_sha256": arch_hash}
+
+    # otherwise this is a python class or function name
     else:
-        source_hash = None
+        weight_kwargs = {"architecture": architecture}
+
+    if model_kwargs is not None:
+        weight_kwargs["kwargs"] = model_kwargs
+
+    return weight_kwargs, tmp_archtecture
+
+
+def _get_weights(original_weight_source, weight_type, root, architecture=None, model_kwargs=None, **kwargs):
+    weight_path = resolve_source(original_weight_source, root)
+    if weight_type is None:
+        weight_type = _infer_weight_type(weight_path)
+    weight_hash = _get_hash(weight_path)
 
     attachments = {"attachments": kwargs["weight_attachments"]} if "weight_attachments" in kwargs else {}
     weight_types = model_spec.raw_nodes.WeightsFormat
     weight_source = _ensure_local_or_url(original_weight_source, root)
 
+    tmp_archtecture = None
     if weight_type == "pytorch_state_dict":
-        # pytorch-state-dict -> we need a source
-        assert source is not None
+        # pytorch-state-dict -> we need an architecture definition
+        weight_kwargs, tmp_file = _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root)
+        weight_kwargs.update(**attachments)
         weights = model_spec.raw_nodes.PytorchStateDictWeightsEntry(
-            source=weight_source, sha256=weight_hash, **attachments
+            source=weight_source, sha256=weight_hash, **weight_kwargs
         )
-        language = "python"
-        framework = "pytorch"
 
     elif weight_type == "onnx":
         weights = model_spec.raw_nodes.OnnxWeightsEntry(
             source=weight_source, sha256=weight_hash, opset_version=kwargs.get("opset_version", 12), **attachments
         )
-        language = None
-        framework = None
 
     elif weight_type == "pytorch_script":
         weights = model_spec.raw_nodes.PytorchScriptWeightsEntry(
             source=weight_source, sha256=weight_hash, **attachments
         )
-        if source is None:
-            language = None
-            framework = None
-        else:
-            language = "python"
-            framework = "pytorch"
 
     elif weight_type == "keras_hdf5":
         weights = model_spec.raw_nodes.KerasHdf5WeightsEntry(
@@ -118,8 +123,6 @@ def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
             tensorflow_version=kwargs.get("tensorflow_version", "1.15"),
             **attachments,
         )
-        language = "python"
-        framework = "tensorflow"
 
     elif weight_type == "tensorflow_saved_model_bundle":
         weights = model_spec.raw_nodes.TensorflowSavedModelBundleWeightsEntry(
@@ -128,8 +131,6 @@ def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
             tensorflow_version=kwargs.get("tensorflow_version", "1.15"),
             **attachments,
         )
-        language = "python"
-        framework = "tensorflow"
 
     elif weight_type == "tensorflow_js":
         weights = model_spec.raw_nodes.TensorflowJsWeightsEntry(
@@ -138,16 +139,14 @@ def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
             tensorflow_version=kwargs.get("tensorflow_version", "1.15"),
             **attachments,
         )
-        language = None
-        framework = None
 
     elif weight_type in weight_types:
         raise ValueError(f"Weight type {weight_type} is not supported yet in 'build_spec'")
     else:
         raise ValueError(f"Invalid weight type {weight_type}, expect one of {weight_types}")
 
     weights = {weight_type: weights}
-    return weights, language, framework, source, source_hash, tmp_source
+    return weights, tmp_archtecture
 
 
 def _get_data_range(data_range, dtype):
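After this hunk `_get_weights` no longer reports language and framework metadata; it returns the weights dictionary together with the path of a temporary architecture file (or None) that the caller has to remove. A minimal sketch of the new call pattern, with hypothetical file names and kwargs (the real calls live in build_model and add_weights below):

import os
from pathlib import Path

# illustration only: weight file, architecture file and kwargs are made up
weights, tmp_architecture = _get_weights(
    "weights.pt",                             # original_weight_source
    "pytorch_state_dict",                     # weight_type
    root=Path("."),
    architecture="my_architecture.py:MyNet",  # "file:callable" form is hashed and localized
    model_kwargs={"n_channels": 1},
)
try:
    pass  # build and export the model package here
finally:
    if tmp_architecture is not None:  # remove the local copy of an absolute architecture file
        os.remove(tmp_architecture)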
@@ -449,7 +448,7 @@ def build_model(
     cite: Dict[str, str],
     output_path: Union[str, Path],
     # model specific optional
-    source: Optional[str] = None,
+    architecture: Optional[str] = None,
     model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None,
     weight_type: Optional[str] = None,
     sample_inputs: Optional[List[str]] = None,
@@ -626,9 +625,7 @@ def build_model(
     covers = _ensure_local(covers, root)
 
     # parse the weights
-    weights, language, framework, source, source_hash, tmp_source = _get_weights(
-        weight_uri, weight_type, source, root, **weight_kwargs
-    )
+    weights, tmp_archtecture = _get_weights(weight_uri, weight_type, root, architecture, model_kwargs, **weight_kwargs)
 
     # validate the sample inputs and outputs (if given)
     if sample_inputs is not None:
@@ -690,11 +687,6 @@ def build_model(
         "run_mode": run_mode,
         "sample_inputs": sample_inputs,
         "sample_outputs": sample_outputs,
-        "framework": framework,
-        "language": language,
-        "source": source,
-        "sha256": source_hash,
-        "kwargs": model_kwargs,
         "links": links,
     }
     kwargs = {k: v for k, v in optional_kwargs.items() if v is not None}
@@ -729,8 +721,8 @@ def build_model(
     except Exception as e:
         raise e
     finally:
-        if tmp_source is not None:
-            os.remove(tmp_source)
+        if tmp_archtecture is not None:
+            os.remove(tmp_archtecture)
 
     model = load_raw_resource_description(model_package)
     return model
@@ -744,12 +736,12 @@ def add_weights(
     **weight_kwargs,
 ):
     """Add weight entry to bioimage.io model."""
-    # we need to patss the weight path as abs path to avoid confusion with different root directories
-    new_weights = _get_weights(Path(weight_uri).absolute(), weight_type, source=None, root=Path("."), **weight_kwargs)[
-        0
-    ]
+    # we need to pass the weight path as abs path to avoid confusion with different root directories
+    new_weights, tmp_arch = _get_weights(Path(weight_uri).absolute(), weight_type, root=Path("."), **weight_kwargs)
     model.weights.update(new_weights)
     if output_path is not None:
         model_package = export_resource_package(model, output_path=output_path)
         model = load_raw_resource_description(model_package)
+    if tmp_arch is not None:
+        os.remove(tmp_arch)
     return model
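Taken together, the build_model.py changes move the architecture information from the removed top-level source/sha256/kwargs fields into the pytorch_state_dict weight entry. A caller now passes it through the new architecture keyword; a hedged sketch with made-up paths, omitting the other required metadata (name, authors, cite, documentation, test tensors):

raw_model = build_model(
    weight_uri="weights.pt",
    weight_type="pytorch_state_dict",
    architecture="my_architecture.py:MyNet",  # replaces the old source=... argument
    model_kwargs={"in_channels": 1, "out_channels": 2},
    output_path="my-model.zip",
    # ... plus the remaining required fields (name, authors, cite, documentation, test_inputs, test_outputs)
)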

tests/build_spec/test_build_spec.py

Lines changed: 11 additions & 4 deletions
@@ -22,15 +22,19 @@ def _test_build_spec(
     cite = {entry.text: entry.doi if entry.url is missing else entry.url for entry in model_spec.cite}
 
     if weight_type == "pytorch_state_dict":
-        source_path = model_spec.source.source_file
-        class_name = model_spec.source.callable_name
+        weight_spec = model_spec.weights["pytorch_state_dict"]
+        source_path = weight_spec.architecture.source_file
+        class_name = weight_spec.architecture.callable_name
+        model_kwargs = None if weight_spec.kwargs is missing else weight_spec.kwargs
         model_source = f"{source_path}:{class_name}"
         weight_type_ = None  # the weight type can be auto-detected
     elif weight_type == "pytorch_script":
         model_source = None
+        model_kwargs = None
         weight_type_ = "pytorch_script"  # the weight type CANNOT be auto-detcted
     else:
         model_source = None
+        model_kwargs = None
         weight_type_ = None  # the weight type can be auto-detected
 
     dep_file = None if model_spec.dependencies is missing else resolve_source(model_spec.dependencies.file, root)
@@ -45,8 +49,6 @@ def _test_build_spec(
         for output in model_spec.outputs
     ]
     kwargs = dict(
-        source=model_source,
-        model_kwargs=model_spec.kwargs,
        weight_uri=weight_source,
        test_inputs=resolve_source(model_spec.test_inputs, root),
        test_outputs=resolve_source(model_spec.test_outputs, root),
@@ -66,6 +68,11 @@ def _test_build_spec(
         output_path=out_path,
         add_deepimagej_config=add_deepimagej_config,
     )
+    # TODO names
+    if model_source is not None:
+        kwargs["source"] = model_source
+    if model_kwargs is not None:
+        kwargs["kwargs"] = model_kwargs
     if tensorflow_version is not None:
         kwargs["tensorflow_version"] = tensorflow_version
     if use_implicit_output_shape:
