Skip to content

Commit 2be9913

Browse files
committed
fix _output_axes
1 parent 73063bd commit 2be9913

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

bioimageio/core/model_adapters/_keras_model_adapter.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
from loguru import logger
55
from numpy.typing import NDArray
66

7-
from bioimageio.core.tensor import Tensor
87
from bioimageio.spec._internal.io_utils import download
98
from bioimageio.spec.model import v0_4, v0_5
109
from bioimageio.spec.model.v0_5 import Version
1110

1211
from .._settings import settings
12+
from ..digest_spec import get_axes_infos
13+
from ..tensor import Tensor
1314
from ._model_adapter import ModelAdapter
1415

1516
os.environ["KERAS_BACKEND"] = settings.keras_backend
@@ -74,7 +75,10 @@ def __init__(
7475
weight_path = download(model_description.weights.keras_hdf5.source).path
7576

7677
self._network = keras.models.load_model(weight_path)
77-
self._output_axes = [tuple(out.axes) for out in model_description.outputs]
78+
self._output_axes = [
79+
tuple(a.id for a in get_axes_infos(out))
80+
for out in model_description.outputs
81+
]
7882

7983
def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
8084
_result: Union[Sequence[NDArray[Any]], NDArray[Any]]
@@ -87,7 +91,11 @@ def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
8791
result = [_result] # type: ignore
8892

8993
assert len(result) == len(self._output_axes)
90-
return [Tensor(r, dims=axes) for r, axes, in zip(result, self._output_axes)]
94+
ret: List[Optional[Tensor]] = []
95+
ret.extend(
96+
[Tensor(r, dims=axes) for r, axes, in zip(result, self._output_axes)]
97+
)
98+
return ret
9199

92100
def unload(self) -> None:
93101
logger.warning(

bioimageio/core/model_adapters/_pytorch_model_adapter.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from bioimageio.spec.utils import download
77

88
from ..axis import AxisId
9-
from ..digest_spec import import_callable
9+
from ..digest_spec import get_axes_infos, import_callable
1010
from ..tensor import Tensor
1111
from ._model_adapter import ModelAdapter
1212

@@ -31,10 +31,7 @@ def __init__(
3131
if torch is None:
3232
raise ImportError("torch")
3333
super().__init__()
34-
self.output_dims = [
35-
tuple(AxisId(a) if isinstance(a, str) else a.id for a in out.axes)
36-
for out in outputs
37-
]
34+
self.output_dims = [tuple(a.id for a in get_axes_infos(out)) for out in outputs]
3835
self._network = self.get_network(weights)
3936
self._devices = self.get_devices(devices)
4037
self._network = self._network.to(self._devices[0])

0 commit comments

Comments
 (0)