File tree Expand file tree Collapse file tree 3 files changed +9
-14
lines changed
prediction_pipeline/_model_adapters Expand file tree Collapse file tree 3 files changed +9
-14
lines changed Original file line number Diff line number Diff line change 11import warnings
22from typing import List , Optional , Sequence
33
4+ # by default, we use the keras integrated with tensorflow
45try :
56 from tensorflow import keras
67except Exception :
Original file line number Diff line number Diff line change @@ -52,8 +52,10 @@ def _unload(self) -> None:
5252
5353 @staticmethod
5454 def get_nn_instance (model_node : nodes .Model , ** kwargs ):
55- assert isinstance (model_node .source , nodes .ImportedSource )
56-
57- joined_kwargs = {} if model_node .kwargs is missing else dict (model_node .kwargs )
55+ weight_spec = model_node .weights .get ("pytorch_state_dict" )
56+ assert weight_spec is not None
57+ assert isinstance (weight_spec .architecture , nodes .ImportedSource )
58+ model_kwargs = weight_spec .kwargs
59+ joined_kwargs = {} if model_kwargs is missing else dict (model_kwargs )
5860 joined_kwargs .update (kwargs )
59- return model_node . source (** joined_kwargs )
61+ return weight_spec . architecture (** joined_kwargs )
Original file line number Diff line number Diff line change 11import torch
2- from marshmallow import missing
3-
4-
5- # NOTE: copied from tiktorch; this should go into python-bioimageio and then we use it from there
6- def get_nn_instance (node , ** kwargs ):
7- joined_kwargs = {} if node .kwargs is missing else dict (node .kwargs ) # type: ignore
8- joined_kwargs .update (kwargs )
9- model = node .source (** joined_kwargs )
10- return model
2+ from bioimageio .core .prediction_pipeline ._model_adapters ._pytorch_model_adapter import PytorchModelAdapter
113
124
135# additional convenience for pytorch state dict, eventually we want this in python-bioimageio too
146# and for each weight format
157def load_model (node ):
16- model = get_nn_instance (node )
8+ model = PytorchModelAdapter . get_nn_instance (node )
179 state = torch .load (node .weights ["pytorch_state_dict" ].source )
1810 model .load_state_dict (state )
1911 model .eval ()
You can’t perform that action at this time.
0 commit comments