@@ -31,13 +31,13 @@ class DynUNetSkipLayer(nn.Module):
3131 forward passes of the network.
3232 """
3333
34- heads : List [torch .Tensor ]
34+ heads : Optional [ List [torch .Tensor ] ]
3535
36- def __init__ (self , index , heads , downsample , upsample , super_head , next_layer ):
36+ def __init__ (self , index , downsample , upsample , next_layer , heads = None , super_head = None ):
3737 super ().__init__ ()
3838 self .downsample = downsample
39- self .upsample = upsample
4039 self .next_layer = next_layer
40+ self .upsample = upsample
4141 self .super_head = super_head
4242 self .heads = heads
4343 self .index = index
@@ -46,8 +46,8 @@ def forward(self, x):
4646 downout = self .downsample (x )
4747 nextout = self .next_layer (downout )
4848 upout = self .upsample (nextout , downout )
49-
50- self .heads [self .index ] = self .super_head (upout )
49+ if self . super_head is not None and self . heads is not None and self . index > 0 :
50+ self .heads [self .index - 1 ] = self .super_head (upout )
5151
5252 return upout
5353
@@ -79,6 +79,8 @@ class DynUNet(nn.Module):
7979 For example, if `strides=((1, 2, 4), 2, 1, 1)`, the minimal spatial size of the input is `(8, 16, 32)`, and
8080 the spatial size of the output is `(8, 8, 8)`.
8181
82+ For backwards compatibility with old weights, please set `strict=False` when calling `load_state_dict`.
83+
8284 Usage example with medical segmentation decathlon dataset is available at:
8385 https://github.com/Project-MONAI/tutorials/tree/master/modules/dynunet_pipeline.
8486
@@ -100,18 +102,16 @@ class DynUNet(nn.Module):
100102 norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``.
101103 act_name: activation layer type and arguments. Defaults to ``leakyrelu``.
102104 deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.
103- If ``True``, in training mode, the forward function will output not only the last feature
104- map , but also the previous feature maps that come from the intermediate up sample layers.
105+ If ``True``, in training mode, the forward function will output not only the final feature map
106+ (from `output_block`) , but also the feature maps that come from the intermediate up sample layers.
105107 In order to unify the return type (the restriction of TorchScript), all intermediate
106- feature maps are interpolated into the same size as the last feature map and stacked together
108+ feature maps are interpolated into the same size as the final feature map and stacked together
107109 (with a new dimension in the first axis)into one single tensor.
108- For instance, if there are three feature maps with shapes: (1, 2, 32, 24), (1, 2, 16, 12) and
109- (1, 2, 8, 6). The last two will be interpolated into (1, 2, 32, 24), and the stacked tensor
110- will has the shape (1, 3, 2, 8, 6 ).
110+ For instance, if there are two intermediate feature maps with shapes: (1, 2, 16, 12) and
111+ (1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then all intermediate feature maps
112+ will be interpolated into (1, 2, 32, 24), and the stacked tensor will has the shape (1, 3, 2, 32, 24 ).
111113 When calculating the loss, you can use torch.unbind to get all feature maps can compute the loss
112114 one by one with the ground truth, then do a weighted average for all losses to achieve the final loss.
113- (To be added: a corresponding tutorial link)
114-
115115 deep_supr_num: number of feature maps that will output during deep supervision head. The
116116 value should be larger than 0 and less than the number of up sample layers.
117117 Defaults to 1.
@@ -160,16 +160,17 @@ def __init__(
160160 self .upsamples = self .get_upsamples ()
161161 self .output_block = self .get_output_block (0 )
162162 self .deep_supervision = deep_supervision
163- self .deep_supervision_heads = self .get_deep_supervision_heads ()
164163 self .deep_supr_num = deep_supr_num
164+ # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
165+ self .heads : List [torch .Tensor ] = [torch .rand (1 )] * self .deep_supr_num
166+ if self .deep_supervision :
167+ self .deep_supervision_heads = self .get_deep_supervision_heads ()
168+ self .check_deep_supr_num ()
169+
165170 self .apply (self .initialize_weights )
166171 self .check_kernel_stride ()
167- self .check_deep_supr_num ()
168172
169- # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
170- self .heads : List [torch .Tensor ] = [torch .rand (1 )] * (len (self .deep_supervision_heads ) + 1 )
171-
172- def create_skips (index , downsamples , upsamples , superheads , bottleneck ):
173+ def create_skips (index , downsamples , upsamples , bottleneck , superheads = None ):
173174 """
174175 Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is
175176 done recursively from the top down since a recursive nn.Module subclass is being used to be compatible
@@ -180,30 +181,50 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck):
180181
181182 if len (downsamples ) != len (upsamples ):
182183 raise ValueError (f"{ len (downsamples )} != { len (upsamples )} " )
183- if (len (downsamples ) - len (superheads )) not in (1 , 0 ):
184- raise ValueError (f"{ len (downsamples )} -(0,1) != { len (superheads )} " )
185184
186185 if len (downsamples ) == 0 : # bottom of the network, pass the bottleneck block
187186 return bottleneck
187+
188+ if superheads is None :
189+ next_layer = create_skips (1 + index , downsamples [1 :], upsamples [1 :], bottleneck )
190+ return DynUNetSkipLayer (index , downsample = downsamples [0 ], upsample = upsamples [0 ], next_layer = next_layer )
191+
192+ super_head_flag = False
188193 if index == 0 : # don't associate a supervision head with self.input_block
189- current_head , rest_heads = nn .Identity (), superheads
190- elif not self .deep_supervision : # bypass supervision heads by passing nn.Identity in place of a real one
191- current_head , rest_heads = nn .Identity (), superheads [1 :]
194+ rest_heads = superheads
192195 else :
193- current_head , rest_heads = superheads [0 ], superheads [1 :]
196+ if len (superheads ) > 0 :
197+ super_head_flag = True
198+ rest_heads = superheads [1 :]
199+ else :
200+ rest_heads = nn .ModuleList ()
194201
195202 # create the next layer down, this will stop at the bottleneck layer
196- next_layer = create_skips (1 + index , downsamples [1 :], upsamples [1 :], rest_heads , bottleneck )
197-
198- return DynUNetSkipLayer (index , self .heads , downsamples [0 ], upsamples [0 ], current_head , next_layer )
199-
200- self .skip_layers = create_skips (
201- 0 ,
202- [self .input_block ] + list (self .downsamples ),
203- self .upsamples [::- 1 ],
204- self .deep_supervision_heads ,
205- self .bottleneck ,
206- )
203+ next_layer = create_skips (1 + index , downsamples [1 :], upsamples [1 :], bottleneck , superheads = rest_heads )
204+ if super_head_flag :
205+ return DynUNetSkipLayer (
206+ index ,
207+ downsample = downsamples [0 ],
208+ upsample = upsamples [0 ],
209+ next_layer = next_layer ,
210+ heads = self .heads ,
211+ super_head = superheads [0 ],
212+ )
213+
214+ return DynUNetSkipLayer (index , downsample = downsamples [0 ], upsample = upsamples [0 ], next_layer = next_layer )
215+
216+ if not self .deep_supervision :
217+ self .skip_layers = create_skips (
218+ 0 , [self .input_block ] + list (self .downsamples ), self .upsamples [::- 1 ], self .bottleneck
219+ )
220+ else :
221+ self .skip_layers = create_skips (
222+ 0 ,
223+ [self .input_block ] + list (self .downsamples ),
224+ self .upsamples [::- 1 ],
225+ self .bottleneck ,
226+ superheads = self .deep_supervision_heads ,
227+ )
207228
208229 def check_kernel_stride (self ):
209230 kernels , strides = self .kernel_size , self .strides
@@ -242,8 +263,7 @@ def forward(self, x):
242263 out = self .output_block (out )
243264 if self .training and self .deep_supervision :
244265 out_all = [out ]
245- feature_maps = self .heads [1 : self .deep_supr_num + 1 ]
246- for feature_map in feature_maps :
266+ for feature_map in self .heads :
247267 out_all .append (interpolate (feature_map , out .shape [2 :]))
248268 return torch .stack (out_all , dim = 1 )
249269 return out
@@ -334,7 +354,7 @@ def get_module_list(
334354 return nn .ModuleList (layers )
335355
336356 def get_deep_supervision_heads (self ):
337- return nn .ModuleList ([self .get_output_block (i + 1 ) for i in range (len ( self .upsamples ) - 1 )])
357+ return nn .ModuleList ([self .get_output_block (i + 1 ) for i in range (self .deep_supr_num )])
338358
339359 @staticmethod
340360 def initialize_weights (module ):
0 commit comments