Skip to content

Commit a5c68f2

Browse files
committed
make multi tensor model onnx compatible (again)
1 parent 92188be commit a5c68f2

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

example_descriptions/models/unet2d_multi_tensor/bioimageio.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ weights:
143143
in_channels: 2
144144
initial_features: 16
145145
out_channels: 2
146-
sha256: 89204f8f3513b3c227127a8137bedaec9eafe49925f7734c73c6650ec135b34e
146+
sha256: 74b6d27cd17b40560e70fde4d57b5675614979ffab4f1e0328a6c3a3f64a1ff2
147147
source: multi_tensor_unet.py
148148
pytorch_version: 1.6
149149
sha256: c498522b3f2b02429b41fe9dbcb722ce0d7ad4cae7fcf8059cee27857ae49b00

example_descriptions/models/unet2d_multi_tensor/multi_tensor_unet.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# type: ignore
2+
from typing import List, Optional
3+
24
import torch
35
import torch.nn as nn
46

@@ -75,11 +77,16 @@ def _apply_default(self, x):
7577

7678
return x
7779

78-
def forward(self, *x):
79-
assert isinstance(x, (list, tuple)), type(x)
80-
# fix issue in onnx export
81-
if isinstance(x[0], list) and len(x) == 1:
82-
x = x[0]
80+
def forward(
81+
self,
82+
x0: torch.Tensor,
83+
x1: Optional[torch.Tensor] = None,
84+
x2: Optional[torch.Tensor] = None,
85+
x3: Optional[torch.Tensor] = None,
86+
x4: Optional[torch.Tensor] = None,
87+
/,
88+
) -> List[torch.Tensor]:
89+
x = [x for x in [x0, x1, x2, x3, x4] if x is not None]
8390
assert len(x) == self.in_channels, f"{len(x)}, {self.in_channels}"
8491
x = torch.cat(x, dim=1)
8592
out = self._apply_default(x)

0 commit comments

Comments
 (0)