1+ """
2+ Created on Fri Aug 14 15:00:00 2025
3+
4+ @author: Anna Grim
5+ @email: anna.grim@alleninstitute.org
6+
7+ Code that implements a 3D U-Net.
8+
9+ """
10+
111import torch
212import torch .nn as nn
313import torch .nn .functional as F
414
515
616class UNet (nn .Module ):
7- def __init__ (self , width_multiplier = 1 , trilinear = True ):
17+ """
18+ 3D U-Net architecture for 3D image data, suitable for tasks such as
19+ denoising or segmentation.
20+
21+ Attributes
22+ ----------
23+ channels : List[int]
24+ Number of channels in each layer after applying "width_multiplier".
25+ trilinear : bool
26+ Flag indicating whether trilinear upsampling is used.
27+ inc : DoubleConv
28+ Initial convolution block.
29+ down1, down2, down3, down4 : Down
30+ Downsampling blocks in the encoder path.
31+ up1, up2, up3, up4 : Up
32+ Upsampling blocks in the decoder path.
33+ outc : OutConv
34+ Final 1x1x1 convolution mapping features to the output channel.
35+ """
36+
37+ def __init__ (self , width_multiplier = 1 , trilinear = True , use_relu = True ):
38+ """
39+ Instantiates a UNet object.
40+
41+ Parameters
42+ ----------
43+ width_multiplier : float, optional
44+ Factor that scales the number of channels in each layer. Default
45+ is 1.
46+ trilinear : bool, optional
47+ If True, use trilinear interpolation for upsampling in decoder
48+ blocks; otherwise, use transposed convolutions. Default is True.
49+ use_relu : bool, optional
50+ If True, use ReLU activations in `DoubleConv` blocks; otherwise,
51+ use LeakyReLU. Default is True.
52+ """
853 # Call parent class
954 super (UNet , self ).__init__ ()
1055
@@ -17,7 +62,7 @@ def __init__(self, width_multiplier=1, trilinear=True):
1762 self .trilinear = trilinear
1863
1964 # Contracting layers
20- self .inc = DoubleConv (1 , self .channels [0 ])
65+ self .inc = DoubleConv (1 , self .channels [0 ], use_relu = use_relu )
2166 self .down1 = Down (self .channels [0 ], self .channels [1 ])
2267 self .down2 = Down (self .channels [1 ], self .channels [2 ])
2368 self .down3 = Down (self .channels [2 ], self .channels [3 ])
@@ -31,6 +76,20 @@ def __init__(self, width_multiplier=1, trilinear=True):
3176 self .outc = OutConv (self .channels [0 ], 1 )
3277
3378 def forward (self , x ):
79+ """
80+ Forward pass of the 3D U-Net.
81+
82+ Parameters
83+ ----------
84+ x : torch.Tensor
85+ Input tensor with shape (B, 1, D, H, W).
86+
87+ Returns
88+ -------
89+ torch.Tensor
90+ Output tensor with shape (B, 1, D, H, W), representing the
91+ denoised image.
92+ """
3493 # Contracting layers
3594 x1 = self .inc (x )
3695 x2 = self .down1 (x1 )
@@ -48,43 +107,155 @@ def forward(self, x):
48107
49108
50109class DoubleConv (nn .Module ):
51- """(convolution => [BN] => ReLU) * 2"""
110+ """
111+ A module that consists of two consecutive 3D convolutional layers, each
112+ followed by batch normalization and a nonlinear activation.
52113
53- def __init__ (self , in_channels , out_channels , mid_channels = None ):
114+ Attributes
115+ ----------
116+ double_conv : nn.Sequential
117+ Sequential module containing two convolutions, batch norms, and
118+ activations.
119+ """
120+
121+ def __init__ (
122+ self , in_channels , out_channels , mid_channels = None , use_relu = True
123+ ):
124+ """
125+ Instantiates a DoubleConv object.
126+
127+ Parameters
128+ ----------
129+ in_channels : int
130+ Number of input channels to this module.
131+ out_channels : int
132+ Number of output channels produced by this module.
133+ mid_channels : int, optional
134+ Number of channels in the intermediate convolution. Default is
135+ None.
136+ use_relu : bool, optional
137+ If True, use ReLU activations; otherwise use LeakyReLU. Default
138+ is True.
139+ """
140+ # Call parent class
54141 super ().__init__ ()
142+
143+ # Check whether to set custom mid channel dimension
55144 if not mid_channels :
56145 mid_channels = out_channels
146+
147+ # Set nonlinear activation
148+ if use_relu :
149+ activation = nn .ReLU (inplace = True )
150+ else :
151+ activation = nn .LeakyReLU (negative_slope = 0.01 , inplace = True )
152+
153+ # Instance attributes
57154 self .double_conv = nn .Sequential (
58155 nn .Conv3d (in_channels , mid_channels , kernel_size = 3 , padding = 1 ),
59156 nn .BatchNorm3d (mid_channels ),
60- nn . ReLU ( inplace = True ) ,
157+ activation ,
61158 nn .Conv3d (mid_channels , out_channels , kernel_size = 3 , padding = 1 ),
62159 nn .BatchNorm3d (out_channels ),
63- nn . ReLU ( inplace = True ),
160+ activation
64161 )
65162
66163 def forward (self , x ):
164+ """
165+ Forward pass of the double convolution module.
166+
167+ Parameters
168+ ----------
169+ x : torch.Tensor
170+ Input tensor with shape (B, C, D, H, W).
171+
172+ Returns
173+ -------
174+ torch.Tensor
175+ Output tensor after double convolution.
176+ """
67177 return self .double_conv (x )
68178
69179
70180class Down (nn .Module ):
71- """Downscaling with maxpool then double conv"""
181+ """
182+ A downsampling module for a 3D U-Net.
183+
184+ Attributes
185+ ----------
186+ maxpool_conv : nn.Sequential
187+ Sequential module containing a MaxPool3d layer followed by a
188+ DoubleConv block.
189+ """
72190
73191 def __init__ (self , in_channels , out_channels ):
192+ """
193+ Instantiates a Down object.
194+
195+ Parameters
196+ ----------
197+ in_channels : int
198+ Number of input channels to this module.
199+ out_channels : int
200+ Number of output channels produced by this module.
201+ """
202+ # Call parent class
74203 super ().__init__ ()
204+
205+ # Instance attributes
75206 self .maxpool_conv = nn .Sequential (
76207 nn .MaxPool3d (2 ), DoubleConv (in_channels , out_channels )
77208 )
78209
79210 def forward (self , x ):
211+ """
212+ Forward pass of the downsampling block.
213+
214+ Parameters
215+ ----------
216+ x : torch.Tensor
217+ Input tensor with shape (B, C, D, H, W).
218+
219+ Returns
220+ -------
221+ torch.Tensor
222+ Output tensor after max pooling and double convolution.
223+ """
80224 return self .maxpool_conv (x )
81225
82226
83227class Up (nn .Module ):
84- """Upscaling then double conv"""
228+ """
229+ An upsampling block for a 3D U-Net that performs spatial upscaling
230+ followed by a double convolution.
231+
232+ Attributes
233+ ----------
234+ up : nn.Module
235+ Upsampling layer (either nn.Upsample or nn.ConvTranspose3d).
236+ conv : DoubleConv
237+ Double convolution block applied after concatenating the skip
238+ connection.
239+ """
85240
86241 def __init__ (self , in_channels , out_channels , trilinear = True ):
242+ """
243+ Instantiates an Up object.
244+
245+ Parameters
246+ ----------
247+ in_channels : int
248+ Number of input channels to this module.
249+ out_channels : int
250+ Number of output channels produced by this module.
251+ trilinear : bool, optional
252+ Indication of whether to use nn.Upsample or nn.ConvTranspose3d.
253+ Default is True, meaning that nn.Upsample is used.
254+ """
255+ # Call parent class
87256 super ().__init__ ()
257+
258+ # Instance attributes
88259 if trilinear :
89260 self .up = nn .Upsample (
90261 scale_factor = 2 , mode = "trilinear" , align_corners = True
@@ -99,8 +270,26 @@ def __init__(self, in_channels, out_channels, trilinear=True):
99270 self .conv = DoubleConv (in_channels , out_channels )
100271
101272 def forward (self , x1 , x2 ):
273+ """
274+ Forward pass of the upsampling block in a 3D U-Net.
275+
276+ Parameters
277+ ----------
278+ x1 : torch.Tensor
279+ Input tensor from the previous decoder layer with shape
280+ (B, C1, D, H1, W1).
281+ x2 : torch.Tensor
282+ Skip connection tensor from the encoder path with shape
283+ (B, C2, D, H2, W2).
284+
285+ Returns
286+ -------
287+ torch.Tensor
288+ Output tensor after upsampling, concatenation with the skip
289+ connection, and double convolution. The output shape is
290+ (B, out_channels, D, H2, W2).
291+ """
102292 x1 = self .up (x1 )
103- # input is CHW
104293 diffY = x2 .size ()[2 ] - x1 .size ()[2 ]
105294 diffX = x2 .size ()[3 ] - x1 .size ()[3 ]
106295
@@ -113,29 +302,47 @@ def forward(self, x1, x2):
113302
114303
115304class OutConv (nn .Module ):
305+ """
306+ Final output convolution layer for a 3D U-Net.
307+
308+ Attributes
309+ ----------
310+ conv : nn.Conv3d
311+ 1x1x1 convolution that maps the feature channels to the output
312+ channels.
313+ """
314+
116315 def __init__ (self , in_channels , out_channels ):
316+ """
317+ Instantiates an OutConv object.
318+
319+ Parameters
320+ ----------
321+ in_channels : int
322+ Number of input channels to this module.
323+ out_channels : int
324+ Number of output channels produced by this module.
325+ """
326+ # Call parent class
117327 super (OutConv , self ).__init__ ()
328+
329+ # Instance attributes
118330 self .conv = nn .Conv3d (in_channels , out_channels , kernel_size = 1 )
119331
120332 def forward (self , x ):
121- return self .conv (x )
333+ """
334+ Forward pass of the output convolution.
122335
336+ Parameters
337+ ----------
338+ x : torch.Tensor
339+ Input tensor from the last decoder layer with shape
340+ (B, C, D, H, W).
123341
124- class DepthwiseSeparableConv3d (nn .Module ):
125- def __init__ (self , nin , nout , kernel_size , padding , kernels_per_layer = 1 ):
126- super (DepthwiseSeparableConv3d , self ).__init__ ()
127- self .depthwise = nn .Conv3d (
128- nin ,
129- nin * kernels_per_layer ,
130- kernel_size = kernel_size ,
131- padding = padding ,
132- groups = nin ,
133- )
134- self .pointwise = nn .Conv3d (
135- nin * kernels_per_layer , nout , kernel_size = 1
136- )
137-
138- def forward (self , x ):
139- out = self .depthwise (x )
140- out = self .pointwise (out )
141- return out
342+ Returns
343+ -------
344+ torch.Tensor
345+ Output tensor after 1x1x1 convolution, with shape
346+ (B, 1, D, H, W).
347+ """
348+ return self .conv (x )
0 commit comments