Skip to content

Commit a961b54

Browse files
Add missing Attachments WIP
1 parent bc27fab commit a961b54

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

bioimageio/core/resource_io/nodes.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,36 +139,41 @@ def __call__(self, *args, **kwargs):
139139

140140

141141
@dataclass
142-
class KerasHdf5WeightsEntry(model_raw_nodes.KerasHdf5WeightsEntry):
142+
class KerasHdf5WeightsEntry(Node, model_raw_nodes.KerasHdf5WeightsEntry):
143143
source: Path = missing
144144

145145

146146
@dataclass
147-
class OnnxWeightsEntry(model_raw_nodes.OnnxWeightsEntry):
147+
class OnnxWeightsEntry(Node, model_raw_nodes.OnnxWeightsEntry):
148148
source: Path = missing
149149

150150

151151
@dataclass
152-
class PytorchStateDictWeightsEntry(model_raw_nodes.PytorchStateDictWeightsEntry):
152+
class PytorchStateDictWeightsEntry(Node, model_raw_nodes.PytorchStateDictWeightsEntry):
153153
source: Path = missing
154154
architecture: Union[_Missing, ImportedSource] = missing
155155

156156

157157
@dataclass
158-
class PytorchScriptWeightsEntry(model_raw_nodes.PytorchScriptWeightsEntry):
158+
class PytorchScriptWeightsEntry(Node, model_raw_nodes.PytorchScriptWeightsEntry):
159159
source: Path = missing
160160

161161

162162
@dataclass
163-
class TensorflowJsWeightsEntry(model_raw_nodes.TensorflowJsWeightsEntry):
163+
class TensorflowJsWeightsEntry(Node, model_raw_nodes.TensorflowJsWeightsEntry):
164164
source: Path = missing
165165

166166

167167
@dataclass
168-
class TensorflowSavedModelBundleWeightsEntry(model_raw_nodes.TensorflowSavedModelBundleWeightsEntry):
168+
class TensorflowSavedModelBundleWeightsEntry(Node, model_raw_nodes.TensorflowSavedModelBundleWeightsEntry):
169169
source: Path = missing
170170

171171

172+
@dataclass
173+
class Attachments(Node, model_raw_nodes.Attachments):
174+
files: List[Path] = missing
175+
176+
172177
WeightsEntry = Union[
173178
KerasHdf5WeightsEntry,
174179
OnnxWeightsEntry,

tests/build_spec/test_build_spec.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def _test_build_spec(
9090
loaded_config = loaded_model.config
9191
assert "deepimagej" in loaded_config
9292

93-
attachments = loaded_model.attachments or {}
94-
if "files" in attachments:
93+
attachments = loaded_model.attachments
94+
if attachments is not missing and attachments.files is not missing:
9595
for attached_file in attachments["files"]:
9696
assert attached_file.exists()
9797

@@ -132,3 +132,7 @@ def test_build_spec_tfjs(any_tensorflow_js_model, tmp_path):
132132

133133
def test_build_spec_deepimagej(unet2d_nuclei_broad_model, tmp_path):
134134
_test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "pytorch_script", add_deepimagej_config=True)
135+
136+
137+
# def test_build_spec_deepimagej_keras(unet2d_keras, tmp_path):
138+
# _test_build_spec(unet2d_keras, tmp_path / "model.zip", "pytorch_script", add_deepimagej_config=True)

0 commit comments

Comments
 (0)