11import abc
22from dataclasses import dataclass
3- from typing import Iterable , List , Optional , Sequence , Tuple
3+ from typing import Iterable , List , Optional , Sequence , Tuple , Union
44
55import xarray as xr
66from marshmallow import missing
77
88from bioimageio .core .resource_io import nodes
9- from bioimageio .core . resource_io . nodes import InputTensor , Model , OutputTensor
9+ from bioimageio .spec . model import raw_nodes
1010from ._combined_processing import CombinedProcessing
1111from ._model_adapters import ModelAdapter , create_model_adapter
1212from ._stat_state import StatsState
1313from ._utils import ComputedMeasures , Sample , TensorName
14+ from .. import load_resource_description
15+ from ..resource_io .utils import resolve_raw_node
1416
1517
1618@dataclass
@@ -54,15 +56,15 @@ def name(self) -> str:
5456
5557 @property
5658 @abc .abstractmethod
57- def input_specs (self ) -> List [InputTensor ]:
59+ def input_specs (self ) -> List [nodes . InputTensor ]:
5860 """
5961 specs of inputs
6062 """
6163 ...
6264
6365 @property
6466 @abc .abstractmethod
65- def output_specs (self ) -> List [OutputTensor ]:
67+ def output_specs (self ) -> List [nodes . OutputTensor ]:
6668 """
6769 specs of outputs
6870 """
@@ -88,7 +90,7 @@ def __init__(
8890 self ,
8991 * ,
9092 name : str ,
91- bioimageio_model : Model ,
93+ bioimageio_model : Union [ nodes . Model , raw_nodes . Model ] ,
9294 preprocessing : CombinedProcessing ,
9395 postprocessing : CombinedProcessing ,
9496 ipt_stats : StatsState ,
@@ -99,8 +101,14 @@ def __init__(
99101 raise NotImplementedError (f"Not yet implemented inference for run mode '{ bioimageio_model .run_mode .name } '" )
100102
101103 self ._name = name
102- self ._input_specs = bioimageio_model .inputs
103- self ._output_specs = bioimageio_model .outputs
104+ if isinstance (bioimageio_model , nodes .Model ):
105+ self ._input_specs = bioimageio_model .inputs
106+ self ._output_specs = bioimageio_model .outputs
107+ else :
108+ assert isinstance (bioimageio_model , raw_nodes .Model )
109+ self ._input_specs = [resolve_raw_node (s , nodes ) for s in bioimageio_model .inputs ]
110+ self ._output_specs = [resolve_raw_node (s , nodes ) for s in bioimageio_model .outputs ]
111+
104112 self ._preprocessing = preprocessing
105113 self ._postprocessing = postprocessing
106114 self ._ipt_stats = ipt_stats
@@ -176,7 +184,7 @@ def unload(self):
176184
177185
178186def create_prediction_pipeline (
179- bioimageio_model : nodes .Model ,
187+ bioimageio_model : Union [ nodes .Model , raw_nodes . Model ] ,
180188 * ,
181189 devices : Optional [Sequence [str ]] = None ,
182190 weight_format : Optional [str ] = None ,
@@ -196,8 +204,16 @@ def create_prediction_pipeline(
196204 model_adapter : ModelAdapter = model_adapter or create_model_adapter (
197205 bioimageio_model = bioimageio_model , devices = devices , weight_format = weight_format
198206 )
207+ if isinstance (bioimageio_model , nodes .Model ):
208+ ipts = bioimageio_model .inputs
209+ outs = bioimageio_model .outputs
210+
211+ else :
212+ assert isinstance (bioimageio_model , raw_nodes .Model )
213+ ipts = [resolve_raw_node (s , nodes ) for s in bioimageio_model .inputs ]
214+ outs = [resolve_raw_node (s , nodes ) for s in bioimageio_model .outputs ]
199215
200- preprocessing = CombinedProcessing (bioimageio_model . inputs )
216+ preprocessing = CombinedProcessing (ipts )
201217
202218 def sample_dataset ():
203219 for tensors in dataset_for_initial_statistics :
@@ -209,7 +225,7 @@ def sample_dataset():
209225 update_dataset_stats_after_n_samples = update_dataset_stats_after_n_samples ,
210226 update_dataset_stats_for_n_samples = update_dataset_stats_for_n_samples ,
211227 )
212- postprocessing = CombinedProcessing (bioimageio_model . outputs )
228+ postprocessing = CombinedProcessing (outs )
213229 out_stats = StatsState (
214230 postprocessing .required_measures ,
215231 dataset = tuple (),
0 commit comments