Skip to content

Commit 6f930ee

Browse files
authored
Merge pull request #5 from AllenNeuralDynamics/refactor-update-train
doc: unet3d
2 parents d020812 + 1965fab commit 6f930ee

File tree

2 files changed

+236
-29
lines changed

2 files changed

+236
-29
lines changed

src/aind_exaspim_image_compression/machine_learning/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(
5959

6060
self.codec = blosc.Blosc(cname="zstd", clevel=5, shuffle=blosc.SHUFFLE)
6161
self.criterion = nn.L1Loss()
62-
self.model = UNet().to("cuda")
62+
self.model = UNet(use_relu=False).to("cuda")
6363
self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
6464
self.scheduler = CosineAnnealingLR(self.optimizer, T_max=25)
6565
self.writer = SummaryWriter(log_dir=log_dir)

src/aind_exaspim_image_compression/machine_learning/unet3d.py

Lines changed: 235 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,55 @@
1+
"""
2+
Created on Fri Aug 14 15:00:00 2025
3+
4+
@author: Anna Grim
5+
@email: anna.grim@alleninstitute.org
6+
7+
Code that implements a 3D U-Net.
8+
9+
"""
10+
111
import torch
212
import torch.nn as nn
313
import torch.nn.functional as F
414

515

616
class UNet(nn.Module):
7-
def __init__(self, width_multiplier=1, trilinear=True):
17+
"""
18+
3D U-Net architecture for 3D image data, suitable for tasks such as
19+
denoising or segmentation.
20+
21+
Attributes
22+
----------
23+
channels : List[int]
24+
Number of channels in each layer after applying "width_multiplier".
25+
trilinear : bool
26+
Flag indicating whether trilinear upsampling is used.
27+
inc : DoubleConv
28+
Initial convolution block.
29+
down1, down2, down3, down4 : Down
30+
Downsampling blocks in the encoder path.
31+
up1, up2, up3, up4 : Up
32+
Upsampling blocks in the decoder path.
33+
outc : OutConv
34+
Final 1x1x1 convolution mapping features to the output channel.
35+
"""
36+
37+
def __init__(self, width_multiplier=1, trilinear=True, use_relu=True):
38+
"""
39+
Instantiates a UNet object.
40+
41+
Parameters
42+
----------
43+
width_multiplier : float, optional
44+
Factor that scales the number of channels in each layer. Default
45+
is 1.
46+
trilinear : bool, optional
47+
If True, use trilinear interpolation for upsampling in decoder
48+
blocks; otherwise, use transposed convolutions. Default is True.
49+
use_relu : bool, optional
50+
If True, use ReLU activations in `DoubleConv` blocks; otherwise,
51+
use LeakyReLU. Default is True.
52+
"""
853
# Call parent class
954
super(UNet, self).__init__()
1055

@@ -17,7 +62,7 @@ def __init__(self, width_multiplier=1, trilinear=True):
1762
self.trilinear = trilinear
1863

1964
# Contracting layers
20-
self.inc = DoubleConv(1, self.channels[0])
65+
self.inc = DoubleConv(1, self.channels[0], use_relu=use_relu)
2166
self.down1 = Down(self.channels[0], self.channels[1])
2267
self.down2 = Down(self.channels[1], self.channels[2])
2368
self.down3 = Down(self.channels[2], self.channels[3])
@@ -31,6 +76,20 @@ def __init__(self, width_multiplier=1, trilinear=True):
3176
self.outc = OutConv(self.channels[0], 1)
3277

3378
def forward(self, x):
79+
"""
80+
Forward pass of the 3D U-Net.
81+
82+
Parameters
83+
----------
84+
x : torch.Tensor
85+
Input tensor with shape (B, 1, D, H, W).
86+
87+
Returns
88+
-------
89+
torch.Tensor
90+
Output tensor with shape (B, 1, D, H, W), representing the
91+
denoised image.
92+
"""
3493
# Contracting layers
3594
x1 = self.inc(x)
3695
x2 = self.down1(x1)
@@ -48,43 +107,155 @@ def forward(self, x):
48107

49108

50109
class DoubleConv(nn.Module):
    """
    Two consecutive 3D convolutional layers, each followed by batch
    normalization and a nonlinear activation.

    Attributes
    ----------
    double_conv : nn.Sequential
        Sequential module containing two convolutions, batch norms, and
        activations.
    """

    def __init__(
        self, in_channels, out_channels, mid_channels=None, use_relu=True
    ):
        """
        Instantiates a DoubleConv object.

        Parameters
        ----------
        in_channels : int
            Number of input channels to this module.
        out_channels : int
            Number of output channels produced by this module.
        mid_channels : int, optional
            Number of channels produced by the first convolution. Default
            is None, in which case "out_channels" is used.
        use_relu : bool, optional
            If True, use ReLU activations; otherwise use LeakyReLU. Default
            is True.
        """
        # Call parent class
        super().__init__()

        # Check whether to set custom mid channel dimension
        if not mid_channels:
            mid_channels = out_channels

        # Build a fresh activation per layer. Reusing one module instance in
        # both positions is functionally equivalent for stateless activations,
        # but it registers the same module twice, which breaks per-layer
        # hooks and module surgery that target a single activation.
        def make_activation():
            if use_relu:
                return nn.ReLU(inplace=True)
            return nn.LeakyReLU(negative_slope=0.01, inplace=True)

        # Instance attributes
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(mid_channels),
            make_activation(),
            nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            make_activation(),
        )

    def forward(self, x):
        """
        Forward pass of the double convolution module.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (B, C, D, H, W).

        Returns
        -------
        torch.Tensor
            Output tensor after double convolution, with "out_channels"
            channels and unchanged spatial size (kernel 3, padding 1).
        """
        return self.double_conv(x)
68178

69179

70180
class Down(nn.Module):
    """
    A downsampling module for a 3D U-Net: 2x max pooling followed by a
    double convolution.

    Attributes
    ----------
    maxpool_conv : nn.Sequential
        Sequential module containing a MaxPool3d layer followed by a
        DoubleConv block.
    """

    def __init__(self, in_channels, out_channels, use_relu=True):
        """
        Instantiates a Down object.

        Parameters
        ----------
        in_channels : int
            Number of input channels to this module.
        out_channels : int
            Number of output channels produced by this module.
        use_relu : bool, optional
            If True, the inner DoubleConv uses ReLU activations; otherwise
            LeakyReLU. Default is True, which matches the previous
            hard-coded behavior, so existing callers are unaffected.
        """
        # Call parent class
        super().__init__()

        # Instance attributes. Forward use_relu so encoder blocks can be
        # configured consistently with UNet.inc (which already receives it).
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool3d(2),
            DoubleConv(in_channels, out_channels, use_relu=use_relu),
        )

    def forward(self, x):
        """
        Forward pass of the downsampling block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (B, C, D, H, W); spatial dims are
            halved by the max pool.

        Returns
        -------
        torch.Tensor
            Output tensor after max pooling and double convolution.
        """
        return self.maxpool_conv(x)
81225

82226

83227
class Up(nn.Module):
84-
"""Upscaling then double conv"""
228+
"""
229+
An upsampling block for a 3D U-Net that performs spatial upscaling
230+
followed by a double convolution.
231+
232+
Attributes
233+
----------
234+
up : nn.Module
235+
Upsampling layer (either nn.Upsample or nn.ConvTranspose3d).
236+
conv : DoubleConv
237+
Double convolution block applied after concatenating the skip
238+
connection.
239+
"""
85240

86241
def __init__(self, in_channels, out_channels, trilinear=True):
242+
"""
243+
Instantiates an Up object.
244+
245+
Parameters
246+
----------
247+
in_channels : int
248+
Number of input channels to this module.
249+
out_channels : int
250+
Number of output channels produced by this module.
251+
trilinear : bool, optional
252+
Indication of whether to use nn.Upsample or nn.ConvTranspose3d.
253+
Default is True, meaning that nn.Upsample is used.
254+
"""
255+
# Call parent class
87256
super().__init__()
257+
258+
# Instance attributes
88259
if trilinear:
89260
self.up = nn.Upsample(
90261
scale_factor=2, mode="trilinear", align_corners=True
@@ -99,8 +270,26 @@ def __init__(self, in_channels, out_channels, trilinear=True):
99270
self.conv = DoubleConv(in_channels, out_channels)
100271

101272
def forward(self, x1, x2):
273+
"""
274+
Forward pass of the upsampling block in a 3D U-Net.
275+
276+
Parameters
277+
----------
278+
x1 : torch.Tensor
279+
Input tensor from the previous decoder layer with shape
280+
(B, C1, D, H1, W1).
281+
x2 : torch.Tensor
282+
Skip connection tensor from the encoder path with shape
283+
(B, C2, D, H2, W2).
284+
285+
Returns
286+
-------
287+
torch.Tensor
288+
Output tensor after upsampling, concatenation with the skip
289+
connection, and double convolution. The output shape is
290+
(B, out_channels, D, H2, W2).
291+
"""
102292
x1 = self.up(x1)
103-
# input is CHW
104293
diffY = x2.size()[2] - x1.size()[2]
105294
diffX = x2.size()[3] - x1.size()[3]
106295

@@ -113,29 +302,47 @@ def forward(self, x1, x2):
113302

114303

115304
class OutConv(nn.Module):
    """
    Final projection layer of a 3D U-Net.

    Attributes
    ----------
    conv : nn.Conv3d
        1x1x1 convolution that maps the feature channels to the output
        channels.
    """

    def __init__(self, in_channels, out_channels):
        """
        Instantiates an OutConv object.

        Parameters
        ----------
        in_channels : int
            Number of input channels to this module.
        out_channels : int
            Number of output channels produced by this module.
        """
        # Call parent class
        super().__init__()

        # Instance attributes: pointwise (kernel size 1) channel projection
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """
        Forward pass of the output convolution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor from the last decoder layer with shape
            (B, C, D, H, W).

        Returns
        -------
        torch.Tensor
            Output tensor after the 1x1x1 convolution, with shape
            (B, out_channels, D, H, W) and unchanged spatial size.
        """
        return self.conv(x)

0 commit comments

Comments
 (0)