|
| 1 | +# Copyright 2020 - 2021 MONAI Consortium |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
| 12 | + |
| 13 | +from typing import List, Sequence, Union |
| 14 | + |
| 15 | +import torch |
| 16 | +import torch.nn as nn |
| 17 | + |
| 18 | +from monai.networks.blocks.dynunet_block_v1 import _UnetBasicBlockV1, _UnetResBlockV1, _UnetUpBlockV1 |
| 19 | +from monai.networks.nets.dynunet import DynUNet, DynUNetSkipLayer |
| 20 | +from monai.utils import deprecated |
| 21 | + |
| 22 | +__all__ = ["DynUNetV1", "DynUnetV1", "DynunetV1"] |
| 23 | + |
| 24 | + |
| 25 | +@deprecated( |
| 26 | + since="0.6.0", |
| 27 | + removed="0.7.0", |
| 28 | + msg_suffix="This module is for backward compatibility purpose only. Please use `DynUNet` instead.", |
| 29 | +) |
| 30 | +class DynUNetV1(DynUNet): |
| 31 | + """ |
| 32 | + This a deprecated reimplementation of a dynamic UNet (DynUNet), please use `monai.networks.nets.DynUNet` instead. |
| 33 | +
|
| 34 | + Args: |
| 35 | + spatial_dims: number of spatial dimensions. |
| 36 | + in_channels: number of input channels. |
| 37 | + out_channels: number of output channels. |
| 38 | + kernel_size: convolution kernel size. |
| 39 | + strides: convolution strides for each blocks. |
| 40 | + upsample_kernel_size: convolution kernel size for transposed convolution layers. |
| 41 | + norm_name: [``"batch"``, ``"instance"``, ``"group"``]. Defaults to "instance". |
| 42 | + deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. |
| 43 | + deep_supr_num: number of feature maps that will output during deep supervision head. Defaults to 1. |
| 44 | + res_block: whether to use residual connection based convolution blocks during the network. |
| 45 | + Defaults to ``False``. |
| 46 | + """ |
| 47 | + |
| 48 | + def __init__( |
| 49 | + self, |
| 50 | + spatial_dims: int, |
| 51 | + in_channels: int, |
| 52 | + out_channels: int, |
| 53 | + kernel_size: Sequence[Union[Sequence[int], int]], |
| 54 | + strides: Sequence[Union[Sequence[int], int]], |
| 55 | + upsample_kernel_size: Sequence[Union[Sequence[int], int]], |
| 56 | + norm_name: str = "instance", |
| 57 | + deep_supervision: bool = False, |
| 58 | + deep_supr_num: int = 1, |
| 59 | + res_block: bool = False, |
| 60 | + ): |
| 61 | + nn.Module.__init__(self) |
| 62 | + self.spatial_dims = spatial_dims |
| 63 | + self.in_channels = in_channels |
| 64 | + self.out_channels = out_channels |
| 65 | + self.kernel_size = kernel_size |
| 66 | + self.strides = strides |
| 67 | + self.upsample_kernel_size = upsample_kernel_size |
| 68 | + self.norm_name = norm_name |
| 69 | + self.conv_block = _UnetResBlockV1 if res_block else _UnetBasicBlockV1 # type: ignore |
| 70 | + self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] |
| 71 | + self.input_block = self.get_input_block() |
| 72 | + self.downsamples = self.get_downsamples() |
| 73 | + self.bottleneck = self.get_bottleneck() |
| 74 | + self.upsamples = self.get_upsamples() |
| 75 | + self.output_block = self.get_output_block(0) |
| 76 | + self.deep_supervision = deep_supervision |
| 77 | + self.deep_supervision_heads = self.get_deep_supervision_heads() |
| 78 | + self.deep_supr_num = deep_supr_num |
| 79 | + self.apply(self.initialize_weights) |
| 80 | + self.check_kernel_stride() |
| 81 | + self.check_deep_supr_num() |
| 82 | + |
| 83 | + # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on |
| 84 | + self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) |
| 85 | + |
| 86 | + def create_skips(index, downsamples, upsamples, superheads, bottleneck): |
| 87 | + """ |
| 88 | + Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is |
| 89 | + done recursively from the top down since a recursive nn.Module subclass is being used to be compatible |
| 90 | + with Torchscript. Initially the length of `downsamples` will be one more than that of `superheads` |
| 91 | + since the `input_block` is passed to this function as the first item in `downsamples`, however this |
| 92 | + shouldn't be associated with a supervision head. |
| 93 | + """ |
| 94 | + |
| 95 | + if len(downsamples) != len(upsamples): |
| 96 | + raise AssertionError(f"{len(downsamples)} != {len(upsamples)}") |
| 97 | + if (len(downsamples) - len(superheads)) not in (1, 0): |
| 98 | + raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}") |
| 99 | + |
| 100 | + if len(downsamples) == 0: # bottom of the network, pass the bottleneck block |
| 101 | + return bottleneck |
| 102 | + if index == 0: # don't associate a supervision head with self.input_block |
| 103 | + current_head, rest_heads = nn.Identity(), superheads |
| 104 | + elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one |
| 105 | + current_head, rest_heads = nn.Identity(), superheads[1:] |
| 106 | + else: |
| 107 | + current_head, rest_heads = superheads[0], superheads[1:] |
| 108 | + |
| 109 | + # create the next layer down, this will stop at the bottleneck layer |
| 110 | + next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) |
| 111 | + |
| 112 | + return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer) |
| 113 | + |
| 114 | + self.skip_layers = create_skips( |
| 115 | + 0, |
| 116 | + [self.input_block] + list(self.downsamples), |
| 117 | + self.upsamples[::-1], |
| 118 | + self.deep_supervision_heads, |
| 119 | + self.bottleneck, |
| 120 | + ) |
| 121 | + |
| 122 | + def get_upsamples(self): |
| 123 | + inp, out = self.filters[1:][::-1], self.filters[:-1][::-1] |
| 124 | + strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1] |
| 125 | + upsample_kernel_size = self.upsample_kernel_size[::-1] |
| 126 | + return self.get_module_list(inp, out, kernel_size, strides, _UnetUpBlockV1, upsample_kernel_size) |
| 127 | + |
| 128 | + @staticmethod |
| 129 | + def initialize_weights(module): |
| 130 | + name = module.__class__.__name__.lower() |
| 131 | + if "conv3d" in name or "conv2d" in name: |
| 132 | + nn.init.kaiming_normal_(module.weight, a=0.01) |
| 133 | + if module.bias is not None: |
| 134 | + nn.init.constant_(module.bias, 0) |
| 135 | + elif "norm" in name: |
| 136 | + nn.init.normal_(module.weight, 1.0, 0.02) |
| 137 | + nn.init.zeros_(module.bias) |
| 138 | + |
| 139 | + |
| 140 | +DynUnetV1 = DynunetV1 = DynUNetV1 |
0 commit comments