Skip to content

Commit c861230

Browse files
authored
Update unet3d.py
1 parent 9d9eb66 commit c861230

File tree

1 file changed

+5
-3
lines changed
  • src/aind_exaspim_image_compression/machine_learning

1 file changed

+5
-3
lines changed

src/aind_exaspim_image_compression/machine_learning/unet3d.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class DoubleConv(nn.Module):
115115
activations.
116116
"""
117117

118-
def __init__(self, in_channels, out_channels, mid_channels=None):
118+
def __init__(self, in_channels, out_channels, kernel_size=3, mid_channels=None):
119119
"""
120120
Instantiates a DoubleConv object.
121121
@@ -125,6 +125,8 @@ def __init__(self, in_channels, out_channels, mid_channels=None):
125125
Number of input channels to this module.
126126
out_channels : int
127127
Number of output channels produced by this module.
128+
kernel_size : int, optional
129+
Size of kernel used in convolutional layers. Default is 3.
128130
mid_channels : int, optional
129131
Number of channels in the intermediate convolution. Default is
130132
None.
@@ -138,10 +140,10 @@ def __init__(self, in_channels, out_channels, mid_channels=None):
138140

139141
# Instance attributes
140142
self.double_conv = nn.Sequential(
141-
nn.Conv3d(in_channels, mid_channels, kernel_size=4, padding=1),
143+
nn.Conv3d(in_channels, mid_channels, kernel_size=kernel_size, padding=1),
142144
nn.BatchNorm3d(mid_channels),
143145
nn.LeakyReLU(negative_slope=0.01, inplace=True),
144-
nn.Conv3d(mid_channels, out_channels, kernel_size=4, padding=1),
146+
nn.Conv3d(mid_channels, out_channels, kernel_size=kernel_size, padding=1),
145147
nn.BatchNorm3d(out_channels),
146148
nn.LeakyReLU(negative_slope=0.01, inplace=True)
147149
)

0 commit comments

Comments
 (0)