44from loguru import logger
55from numpy .typing import NDArray
66
7- from bioimageio .core .tensor import Tensor
87from bioimageio .spec ._internal .io_utils import download
98from bioimageio .spec .model import v0_4 , v0_5
109from bioimageio .spec .model .v0_5 import Version
1110
1211from .._settings import settings
12+ from ..digest_spec import get_axes_infos
13+ from ..tensor import Tensor
1314from ._model_adapter import ModelAdapter
1415
1516os .environ ["KERAS_BACKEND" ] = settings .keras_backend
@@ -74,7 +75,10 @@ def __init__(
7475 weight_path = download (model_description .weights .keras_hdf5 .source ).path
7576
7677 self ._network = keras .models .load_model (weight_path )
77- self ._output_axes = [tuple (out .axes ) for out in model_description .outputs ]
78+ self ._output_axes = [
79+ tuple (a .id for a in get_axes_infos (out ))
80+ for out in model_description .outputs
81+ ]
7882
7983 def forward (self , * input_tensors : Optional [Tensor ]) -> List [Optional [Tensor ]]:
8084 _result : Union [Sequence [NDArray [Any ]], NDArray [Any ]]
@@ -87,7 +91,11 @@ def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
8791 result = [_result ] # type: ignore
8892
8993 assert len (result ) == len (self ._output_axes )
90- return [Tensor (r , dims = axes ) for r , axes , in zip (result , self ._output_axes )]
94+ ret : List [Optional [Tensor ]] = []
95+ ret .extend (
96+ [Tensor (r , dims = axes ) for r , axes , in zip (result , self ._output_axes )]
97+ )
98+ return ret
9199
92100 def unload (self ) -> None :
93101 logger .warning (
0 commit comments