Skip to content

Commit 0711e04

Browse files
authored
previous dynunet (#2534)
Signed-off-by: Wenqi Li <[email protected]>
1 parent e8124da commit 0711e04

File tree

3 files changed

+420
-0
lines changed

3 files changed

+420
-0
lines changed
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Sequence, Union
13+
14+
import numpy as np
15+
import torch.nn as nn
16+
17+
from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, UnetUpBlock, get_conv_layer
18+
from monai.networks.layers.factories import Norm
19+
from monai.networks.layers.utils import get_act_layer
20+
21+
22+
class _UnetResBlockV1(UnetResBlock):
23+
"""
24+
UnetResBlock for backward compatibility purpose.
25+
"""
26+
27+
def __init__(
28+
self,
29+
spatial_dims: int,
30+
in_channels: int,
31+
out_channels: int,
32+
kernel_size: Union[Sequence[int], int],
33+
stride: Union[Sequence[int], int],
34+
norm_name: str,
35+
):
36+
nn.Module.__init__(self)
37+
self.conv1 = get_conv_layer(
38+
spatial_dims,
39+
in_channels,
40+
out_channels,
41+
kernel_size=kernel_size,
42+
stride=stride,
43+
conv_only=True,
44+
)
45+
self.conv2 = get_conv_layer(
46+
spatial_dims,
47+
out_channels,
48+
out_channels,
49+
kernel_size=kernel_size,
50+
stride=1,
51+
conv_only=True,
52+
)
53+
self.conv3 = get_conv_layer(
54+
spatial_dims,
55+
in_channels,
56+
out_channels,
57+
kernel_size=1,
58+
stride=stride,
59+
conv_only=True,
60+
)
61+
self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01}))
62+
self.norm1 = _get_norm_layer(spatial_dims, out_channels, norm_name)
63+
self.norm2 = _get_norm_layer(spatial_dims, out_channels, norm_name)
64+
self.norm3 = _get_norm_layer(spatial_dims, out_channels, norm_name)
65+
self.downsample = in_channels != out_channels
66+
stride_np = np.atleast_1d(stride)
67+
if not np.all(stride_np == 1):
68+
self.downsample = True
69+
70+
71+
class _UnetBasicBlockV1(UnetBasicBlock):
72+
"""
73+
UnetBasicBlock for backward compatibility purpose.
74+
"""
75+
76+
def __init__(
77+
self,
78+
spatial_dims: int,
79+
in_channels: int,
80+
out_channels: int,
81+
kernel_size: Union[Sequence[int], int],
82+
stride: Union[Sequence[int], int],
83+
norm_name: str,
84+
):
85+
nn.Module.__init__(self)
86+
self.conv1 = get_conv_layer(
87+
spatial_dims,
88+
in_channels,
89+
out_channels,
90+
kernel_size=kernel_size,
91+
stride=stride,
92+
conv_only=True,
93+
)
94+
self.conv2 = get_conv_layer(
95+
spatial_dims,
96+
out_channels,
97+
out_channels,
98+
kernel_size=kernel_size,
99+
stride=1,
100+
conv_only=True,
101+
)
102+
self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01}))
103+
self.norm1 = _get_norm_layer(spatial_dims, out_channels, norm_name)
104+
self.norm2 = _get_norm_layer(spatial_dims, out_channels, norm_name)
105+
106+
107+
class _UnetUpBlockV1(UnetUpBlock):
108+
"""
109+
UnetUpBlock for backward compatibility purpose.
110+
"""
111+
112+
def __init__(
113+
self,
114+
spatial_dims: int,
115+
in_channels: int,
116+
out_channels: int,
117+
kernel_size: Union[Sequence[int], int],
118+
stride: Union[Sequence[int], int],
119+
upsample_kernel_size: Union[Sequence[int], int],
120+
norm_name: str,
121+
):
122+
nn.Module.__init__(self)
123+
upsample_stride = upsample_kernel_size
124+
self.transp_conv = get_conv_layer(
125+
spatial_dims,
126+
in_channels,
127+
out_channels,
128+
kernel_size=upsample_kernel_size,
129+
stride=upsample_stride,
130+
conv_only=True,
131+
is_transposed=True,
132+
)
133+
self.conv_block = _UnetBasicBlockV1(
134+
spatial_dims,
135+
out_channels + out_channels,
136+
out_channels,
137+
kernel_size=kernel_size,
138+
stride=1,
139+
norm_name=norm_name,
140+
)
141+
142+
143+
def _get_norm_layer(spatial_dims: int, out_channels: int, norm_name: str, num_groups: int = 16):
144+
if norm_name not in ["batch", "instance", "group"]:
145+
raise ValueError(f"Unsupported normalization mode: {norm_name}")
146+
if norm_name == "group":
147+
if out_channels % num_groups != 0:
148+
raise AssertionError("out_channels should be divisible by num_groups.")
149+
norm = Norm[norm_name, spatial_dims](num_groups=num_groups, num_channels=out_channels, affine=True)
150+
else:
151+
norm = Norm[norm_name, spatial_dims](out_channels, affine=True)
152+
return norm

monai/networks/nets/dynunet_v1.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
13+
from typing import List, Sequence, Union
14+
15+
import torch
16+
import torch.nn as nn
17+
18+
from monai.networks.blocks.dynunet_block_v1 import _UnetBasicBlockV1, _UnetResBlockV1, _UnetUpBlockV1
19+
from monai.networks.nets.dynunet import DynUNet, DynUNetSkipLayer
20+
from monai.utils import deprecated
21+
22+
__all__ = ["DynUNetV1", "DynUnetV1", "DynunetV1"]
23+
24+
25+
@deprecated(
26+
since="0.6.0",
27+
removed="0.7.0",
28+
msg_suffix="This module is for backward compatibility purpose only. Please use `DynUNet` instead.",
29+
)
30+
class DynUNetV1(DynUNet):
31+
"""
32+
This a deprecated reimplementation of a dynamic UNet (DynUNet), please use `monai.networks.nets.DynUNet` instead.
33+
34+
Args:
35+
spatial_dims: number of spatial dimensions.
36+
in_channels: number of input channels.
37+
out_channels: number of output channels.
38+
kernel_size: convolution kernel size.
39+
strides: convolution strides for each blocks.
40+
upsample_kernel_size: convolution kernel size for transposed convolution layers.
41+
norm_name: [``"batch"``, ``"instance"``, ``"group"``]. Defaults to "instance".
42+
deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.
43+
deep_supr_num: number of feature maps that will output during deep supervision head. Defaults to 1.
44+
res_block: whether to use residual connection based convolution blocks during the network.
45+
Defaults to ``False``.
46+
"""
47+
48+
def __init__(
49+
self,
50+
spatial_dims: int,
51+
in_channels: int,
52+
out_channels: int,
53+
kernel_size: Sequence[Union[Sequence[int], int]],
54+
strides: Sequence[Union[Sequence[int], int]],
55+
upsample_kernel_size: Sequence[Union[Sequence[int], int]],
56+
norm_name: str = "instance",
57+
deep_supervision: bool = False,
58+
deep_supr_num: int = 1,
59+
res_block: bool = False,
60+
):
61+
nn.Module.__init__(self)
62+
self.spatial_dims = spatial_dims
63+
self.in_channels = in_channels
64+
self.out_channels = out_channels
65+
self.kernel_size = kernel_size
66+
self.strides = strides
67+
self.upsample_kernel_size = upsample_kernel_size
68+
self.norm_name = norm_name
69+
self.conv_block = _UnetResBlockV1 if res_block else _UnetBasicBlockV1 # type: ignore
70+
self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))]
71+
self.input_block = self.get_input_block()
72+
self.downsamples = self.get_downsamples()
73+
self.bottleneck = self.get_bottleneck()
74+
self.upsamples = self.get_upsamples()
75+
self.output_block = self.get_output_block(0)
76+
self.deep_supervision = deep_supervision
77+
self.deep_supervision_heads = self.get_deep_supervision_heads()
78+
self.deep_supr_num = deep_supr_num
79+
self.apply(self.initialize_weights)
80+
self.check_kernel_stride()
81+
self.check_deep_supr_num()
82+
83+
# initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
84+
self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1)
85+
86+
def create_skips(index, downsamples, upsamples, superheads, bottleneck):
87+
"""
88+
Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is
89+
done recursively from the top down since a recursive nn.Module subclass is being used to be compatible
90+
with Torchscript. Initially the length of `downsamples` will be one more than that of `superheads`
91+
since the `input_block` is passed to this function as the first item in `downsamples`, however this
92+
shouldn't be associated with a supervision head.
93+
"""
94+
95+
if len(downsamples) != len(upsamples):
96+
raise AssertionError(f"{len(downsamples)} != {len(upsamples)}")
97+
if (len(downsamples) - len(superheads)) not in (1, 0):
98+
raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}")
99+
100+
if len(downsamples) == 0: # bottom of the network, pass the bottleneck block
101+
return bottleneck
102+
if index == 0: # don't associate a supervision head with self.input_block
103+
current_head, rest_heads = nn.Identity(), superheads
104+
elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one
105+
current_head, rest_heads = nn.Identity(), superheads[1:]
106+
else:
107+
current_head, rest_heads = superheads[0], superheads[1:]
108+
109+
# create the next layer down, this will stop at the bottleneck layer
110+
next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck)
111+
112+
return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer)
113+
114+
self.skip_layers = create_skips(
115+
0,
116+
[self.input_block] + list(self.downsamples),
117+
self.upsamples[::-1],
118+
self.deep_supervision_heads,
119+
self.bottleneck,
120+
)
121+
122+
def get_upsamples(self):
123+
inp, out = self.filters[1:][::-1], self.filters[:-1][::-1]
124+
strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1]
125+
upsample_kernel_size = self.upsample_kernel_size[::-1]
126+
return self.get_module_list(inp, out, kernel_size, strides, _UnetUpBlockV1, upsample_kernel_size)
127+
128+
@staticmethod
129+
def initialize_weights(module):
130+
name = module.__class__.__name__.lower()
131+
if "conv3d" in name or "conv2d" in name:
132+
nn.init.kaiming_normal_(module.weight, a=0.01)
133+
if module.bias is not None:
134+
nn.init.constant_(module.bias, 0)
135+
elif "norm" in name:
136+
nn.init.normal_(module.weight, 1.0, 0.02)
137+
nn.init.zeros_(module.bias)
138+
139+
140+
DynUnetV1 = DynunetV1 = DynUNetV1

0 commit comments

Comments
 (0)