Skip to content

Commit 8a454bd

Browse files
authored
fixes #3331 (#3334)
Signed-off-by: Wenqi Li <[email protected]>
1 parent af21f11 commit 8a454bd

File tree

2 files changed

+4
-16
lines changed

2 files changed

+4
-16
lines changed

monai/networks/nets/segresnet.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def __init__(
235235
in_channels=in_channels,
236236
out_channels=out_channels,
237237
dropout_prob=dropout_prob,
238+
act=act,
238239
norm=norm,
239240
use_conv_final=use_conv_final,
240241
blocks_down=blocks_down,
@@ -318,25 +319,11 @@ def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor):
318319

319320
def forward(self, x):
320321
net_input = x
321-
x = self.convInit(x)
322-
if self.dropout_prob is not None:
323-
x = self.dropout(x)
324-
325-
down_x = []
326-
for down in self.down_layers:
327-
x = down(x)
328-
down_x.append(x)
329-
322+
x, down_x = self.encode(x)
330323
down_x.reverse()
331324

332325
vae_input = x
333-
334-
for i, (up, upl) in enumerate(zip(self.up_samples, self.up_layers)):
335-
x = up(x) + down_x[i + 1]
336-
x = upl(x)
337-
338-
if self.use_conv_final:
339-
x = self.conv_final(x)
326+
x = self.decode(x, down_x)
340327

341328
if self.training:
342329
vae_loss = self._get_vae_loss(net_input, vae_input)

tests/test_segresnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
"init_filters": init_filters,
7171
"out_channels": out_channels,
7272
"upsample_mode": upsample_mode,
73+
"act": ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
7374
"input_image_size": ([16] * spatial_dims),
7475
"vae_estimate_std": vae_estimate_std,
7576
},

0 commit comments

Comments
 (0)