Skip to content

Commit da34a74

Browse files
Update to new pytorch_state_dict weight spec
1 parent a6c41f3 commit da34a74

File tree

3 files changed

+9
-14
lines changed

3 files changed

+9
-14
lines changed

bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22
from typing import List, Optional, Sequence
33

4+
# by default, we use the keras integrated with tensorflow
45
try:
56
from tensorflow import keras
67
except Exception:

bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ def _unload(self) -> None:
5252

5353
@staticmethod
5454
def get_nn_instance(model_node: nodes.Model, **kwargs):
55-
assert isinstance(model_node.source, nodes.ImportedSource)
56-
57-
joined_kwargs = {} if model_node.kwargs is missing else dict(model_node.kwargs)
55+
weight_spec = model_node.weights.get("pytorch_state_dict")
56+
assert weight_spec is not None
57+
assert isinstance(weight_spec.architecture, nodes.ImportedSource)
58+
model_kwargs = weight_spec.kwargs
59+
joined_kwargs = {} if model_kwargs is missing else dict(model_kwargs)
5860
joined_kwargs.update(kwargs)
59-
return model_node.source(**joined_kwargs)
61+
return weight_spec.architecture(**joined_kwargs)

bioimageio/core/weight_converter/torch/utils.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,11 @@
11
import torch
2-
from marshmallow import missing
3-
4-
5-
# NOTE: copied from tiktorch; this should go into python-bioimageio and then we use it from there
6-
def get_nn_instance(node, **kwargs):
7-
joined_kwargs = {} if node.kwargs is missing else dict(node.kwargs) # type: ignore
8-
joined_kwargs.update(kwargs)
9-
model = node.source(**joined_kwargs)
10-
return model
2+
from bioimageio.core.prediction_pipeline._model_adapters._pytorch_model_adapter import PytorchModelAdapter
113

124

135
# additional convenience for pytorch state dict, eventually we want this in python-bioimageio too
146
# and for each weight format
157
def load_model(node):
16-
model = get_nn_instance(node)
8+
model = PytorchModelAdapter.get_nn_instance(node)
179
state = torch.load(node.weights["pytorch_state_dict"].source)
1810
model.load_state_dict(state)
1911
model.eval()

0 commit comments

Comments
 (0)