Skip to content

Commit 98c1c43

Browse files
3293 Remove extra deep supervision modules of DynUNet (#3427)
* enhance dynunet Signed-off-by: Yiheng Wang <[email protected]> * fix black issue Signed-off-by: Yiheng Wang <[email protected]> * use strict=False Signed-off-by: Yiheng Wang <[email protected]> * fix black 21.12 error Signed-off-by: Yiheng Wang <[email protected]> * enhance code and update docstring Signed-off-by: Yiheng Wang <[email protected]>
1 parent a17813b commit 98c1c43

File tree

2 files changed

+62
-42
lines changed

2 files changed

+62
-42
lines changed

monai/networks/nets/dynunet.py

Lines changed: 59 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ class DynUNetSkipLayer(nn.Module):
3131
forward passes of the network.
3232
"""
3333

34-
heads: List[torch.Tensor]
34+
heads: Optional[List[torch.Tensor]]
3535

36-
def __init__(self, index, heads, downsample, upsample, super_head, next_layer):
36+
def __init__(self, index, downsample, upsample, next_layer, heads=None, super_head=None):
3737
super().__init__()
3838
self.downsample = downsample
39-
self.upsample = upsample
4039
self.next_layer = next_layer
40+
self.upsample = upsample
4141
self.super_head = super_head
4242
self.heads = heads
4343
self.index = index
@@ -46,8 +46,8 @@ def forward(self, x):
4646
downout = self.downsample(x)
4747
nextout = self.next_layer(downout)
4848
upout = self.upsample(nextout, downout)
49-
50-
self.heads[self.index] = self.super_head(upout)
49+
if self.super_head is not None and self.heads is not None and self.index > 0:
50+
self.heads[self.index - 1] = self.super_head(upout)
5151

5252
return upout
5353

@@ -79,6 +79,8 @@ class DynUNet(nn.Module):
7979
For example, if `strides=((1, 2, 4), 2, 1, 1)`, the minimal spatial size of the input is `(8, 16, 32)`, and
8080
the spatial size of the output is `(8, 8, 8)`.
8181
82+
For backwards compatibility with old weights, please set `strict=False` when calling `load_state_dict`.
83+
8284
Usage example with medical segmentation decathlon dataset is available at:
8385
https://github.com/Project-MONAI/tutorials/tree/master/modules/dynunet_pipeline.
8486
@@ -100,18 +102,16 @@ class DynUNet(nn.Module):
100102
norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``.
101103
act_name: activation layer type and arguments. Defaults to ``leakyrelu``.
102104
deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.
103-
If ``True``, in training mode, the forward function will output not only the last feature
104-
map, but also the previous feature maps that come from the intermediate up sample layers.
105+
If ``True``, in training mode, the forward function will output not only the final feature map
106+
(from `output_block`), but also the feature maps that come from the intermediate up sample layers.
105107
In order to unify the return type (the restriction of TorchScript), all intermediate
106-
feature maps are interpolated into the same size as the last feature map and stacked together
108+
feature maps are interpolated into the same size as the final feature map and stacked together
107109
(with a new dimension in the first axis)into one single tensor.
108-
For instance, if there are three feature maps with shapes: (1, 2, 32, 24), (1, 2, 16, 12) and
109-
(1, 2, 8, 6). The last two will be interpolated into (1, 2, 32, 24), and the stacked tensor
110-
will has the shape (1, 3, 2, 8, 6).
110+
For instance, if there are two intermediate feature maps with shapes: (1, 2, 16, 12) and
111+
(1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then all intermediate feature maps
112+
will be interpolated into (1, 2, 32, 24), and the stacked tensor will has the shape (1, 3, 2, 32, 24).
111113
When calculating the loss, you can use torch.unbind to get all feature maps can compute the loss
112114
one by one with the ground truth, then do a weighted average for all losses to achieve the final loss.
113-
(To be added: a corresponding tutorial link)
114-
115115
deep_supr_num: number of feature maps that will output during deep supervision head. The
116116
value should be larger than 0 and less than the number of up sample layers.
117117
Defaults to 1.
@@ -160,16 +160,17 @@ def __init__(
160160
self.upsamples = self.get_upsamples()
161161
self.output_block = self.get_output_block(0)
162162
self.deep_supervision = deep_supervision
163-
self.deep_supervision_heads = self.get_deep_supervision_heads()
164163
self.deep_supr_num = deep_supr_num
164+
# initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
165+
self.heads: List[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num
166+
if self.deep_supervision:
167+
self.deep_supervision_heads = self.get_deep_supervision_heads()
168+
self.check_deep_supr_num()
169+
165170
self.apply(self.initialize_weights)
166171
self.check_kernel_stride()
167-
self.check_deep_supr_num()
168172

169-
# initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
170-
self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1)
171-
172-
def create_skips(index, downsamples, upsamples, superheads, bottleneck):
173+
def create_skips(index, downsamples, upsamples, bottleneck, superheads=None):
173174
"""
174175
Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is
175176
done recursively from the top down since a recursive nn.Module subclass is being used to be compatible
@@ -180,30 +181,50 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck):
180181

181182
if len(downsamples) != len(upsamples):
182183
raise ValueError(f"{len(downsamples)} != {len(upsamples)}")
183-
if (len(downsamples) - len(superheads)) not in (1, 0):
184-
raise ValueError(f"{len(downsamples)}-(0,1) != {len(superheads)}")
185184

186185
if len(downsamples) == 0: # bottom of the network, pass the bottleneck block
187186
return bottleneck
187+
188+
if superheads is None:
189+
next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck)
190+
return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer)
191+
192+
super_head_flag = False
188193
if index == 0: # don't associate a supervision head with self.input_block
189-
current_head, rest_heads = nn.Identity(), superheads
190-
elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one
191-
current_head, rest_heads = nn.Identity(), superheads[1:]
194+
rest_heads = superheads
192195
else:
193-
current_head, rest_heads = superheads[0], superheads[1:]
196+
if len(superheads) > 0:
197+
super_head_flag = True
198+
rest_heads = superheads[1:]
199+
else:
200+
rest_heads = nn.ModuleList()
194201

195202
# create the next layer down, this will stop at the bottleneck layer
196-
next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck)
197-
198-
return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer)
199-
200-
self.skip_layers = create_skips(
201-
0,
202-
[self.input_block] + list(self.downsamples),
203-
self.upsamples[::-1],
204-
self.deep_supervision_heads,
205-
self.bottleneck,
206-
)
203+
next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck, superheads=rest_heads)
204+
if super_head_flag:
205+
return DynUNetSkipLayer(
206+
index,
207+
downsample=downsamples[0],
208+
upsample=upsamples[0],
209+
next_layer=next_layer,
210+
heads=self.heads,
211+
super_head=superheads[0],
212+
)
213+
214+
return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer)
215+
216+
if not self.deep_supervision:
217+
self.skip_layers = create_skips(
218+
0, [self.input_block] + list(self.downsamples), self.upsamples[::-1], self.bottleneck
219+
)
220+
else:
221+
self.skip_layers = create_skips(
222+
0,
223+
[self.input_block] + list(self.downsamples),
224+
self.upsamples[::-1],
225+
self.bottleneck,
226+
superheads=self.deep_supervision_heads,
227+
)
207228

208229
def check_kernel_stride(self):
209230
kernels, strides = self.kernel_size, self.strides
@@ -242,8 +263,7 @@ def forward(self, x):
242263
out = self.output_block(out)
243264
if self.training and self.deep_supervision:
244265
out_all = [out]
245-
feature_maps = self.heads[1 : self.deep_supr_num + 1]
246-
for feature_map in feature_maps:
266+
for feature_map in self.heads:
247267
out_all.append(interpolate(feature_map, out.shape[2:]))
248268
return torch.stack(out_all, dim=1)
249269
return out
@@ -334,7 +354,7 @@ def get_module_list(
334354
return nn.ModuleList(layers)
335355

336356
def get_deep_supervision_heads(self):
337-
return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)])
357+
return nn.ModuleList([self.get_output_block(i + 1) for i in range(self.deep_supr_num)])
338358

339359
@staticmethod
340360
def initialize_weights(module):

tests/test_network_consistency.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import monai.networks.nets as nets
2323
from monai.utils import set_determinism
2424

25-
extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA", None)
25+
extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA")
2626

2727
TESTS = []
2828
if extra_test_data_dir is not None:
@@ -60,8 +60,8 @@ def test_network_consistency(self, net_name, data_path, json_path):
6060
json_file.close()
6161

6262
# Create model
63-
model = nets.__dict__[net_name](**model_params)
64-
model.load_state_dict(loaded_data["model"])
63+
model = getattr(nets, net_name)(**model_params)
64+
model.load_state_dict(loaded_data["model"], strict=False)
6565
model.eval()
6666

6767
in_data = loaded_data["in_data"]

0 commit comments

Comments
 (0)