
Commit 404b365

Fix(UNet): Remove top level skip connection for N2V2 and reduce number of parameters in decoder (#503)
## Description

> [!NOTE]
> **tldr**: The top-level skip connection was still being concatenated for N2V2 because an incorrect shape comparison was used to decide whether the skip connection should be added. Additionally, the number of input and output features in the decoder blocks has been modified to more closely match the original n2v package implementation.

### Background - why do we need this PR?

The erroneous skip connection caused poor performance for N2V2 and is an incorrect implementation. The decoder in CAREamics had many more parameters than in the original n2v package because each decoder block had double the input and output features of its n2v counterpart. The shape-comparison test for the skip connection was probably wrong because the output features were double what was expected.

### Overview - what changed?

The check that adds the skip connection is now based on the layer index rather than on comparing the shape of the input with the skip-connection features. The calculation of the number of input and output channels in the UNet decoder has been modified to match the n2v TensorFlow implementation. Additionally, a small fix to the next-gen dataset random patching allows patches to cover the final row/column of pixels; the relevant test has been updated accordingly.

### Implementation - how did you implement the changes?

The input features are now the initial input features multiplied by `2 ** (depth - n - 1)` rather than `2 ** (depth - n)`. The skip-connection check is now `if (not self.n2v2) or (self.n2v2 and (i // 2 < depth - 1))`.

## Changes Made

### Modified features or files

- `UnetDecoder.__init__`
- `UnetDecoder.forward`
- `careamics.dataset_ng.patching_strategies.random_patching._generate_random_coords`
- test `test_random_coords`

## How has this been tested?

All tests pass, and the changes have been verified in notebooks.
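The channel arithmetic described above can be checked in isolation. The following is a minimal sketch (the helper `decoder_channels` is hypothetical, not part of the codebase) that reproduces the per-level `(in_channels, out_channels)` calculation for the `Conv_Block`s of the decoder:

```python
def decoder_channels(num_channels_init: int, depth: int, n2v2: bool, groups: int = 1):
    """Sketch of the decoder per-level channel calculation from this PR."""
    channels = []
    for n in range(depth):
        in_channels = (num_channels_init * 2 ** (depth - n - 1)) * groups
        # final decoder block has the same number of in and out features
        out_channels = in_channels // 2 if n != depth - 1 else in_channels
        # double the inputs to account for the skip-connection concat,
        # except at the top level for N2V2 (skip connection removed there)
        if not (n2v2 and n == depth - 1):
            in_channels *= 2
        channels.append((in_channels, out_channels))
    return channels

print(decoder_channels(32, 2, n2v2=True))   # [(128, 32), (32, 32)]
print(decoder_channels(32, 2, n2v2=False))  # [(128, 32), (64, 32)]
```

For `num_channels_init=32` and `depth=2` with `n2v2=True`, this matches the model dumps below: the first decoder block is `Conv2d(128, 64)` then `Conv2d(64, 32)` (intermediate multiplier 2), and the top-level block is `Conv2d(32, 32)` twice.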
## Additional Notes and Examples

TensorFlow model summary for the N2V2 config:

```
Model: "model"
____________________________________________________________________________________________________
 Layer (type)                                Output Shape             Param #  Connected to
====================================================================================================
 input (InputLayer)                          [(None, None, None, 1)]  0        []
 channel_0down_level_0_no_0 (Conv2D)         (None, None, None, 32)   320      ['input[0][0]']
 batch_normalization (BatchNormalization)    (None, None, None, 32)   128      ['channel_0down_level_0_no_0[0][0]']
 activation (Activation)                     (None, None, None, 32)   0        ['batch_normalization[0][0]']
 channel_0down_level_0_no_1 (Conv2D)         (None, None, None, 32)   9248     ['activation[0][0]']
 batch_normalization_1 (BatchNormalization)  (None, None, None, 32)   128      ['channel_0down_level_0_no_1[0][0]']
 activation_1 (Activation)                   (None, None, None, 32)   0        ['batch_normalization_1[0][0]']
 channel_0max_0 (MaxBlurPool2D)              (None, None, None, 32)   288      ['activation_1[0][0]']
 channel_0down_level_1_no_0 (Conv2D)         (None, None, None, 64)   18496    ['channel_0max_0[0][0]']
 batch_normalization_2 (BatchNormalization)  (None, None, None, 64)   256      ['channel_0down_level_1_no_0[0][0]']
 activation_2 (Activation)                   (None, None, None, 64)   0        ['batch_normalization_2[0][0]']
 channel_0down_level_1_no_1 (Conv2D)         (None, None, None, 64)   36928    ['activation_2[0][0]']
 batch_normalization_3 (BatchNormalization)  (None, None, None, 64)   256      ['channel_0down_level_1_no_1[0][0]']
 activation_3 (Activation)                   (None, None, None, 64)   0        ['batch_normalization_3[0][0]']
 channel_0max_1 (MaxBlurPool2D)              (None, None, None, 64)   576      ['activation_3[0][0]']
 channel_0middle_0 (Conv2D)                  (None, None, None, 128)  73856    ['channel_0max_1[0][0]']
 batch_normalization_4 (BatchNormalization)  (None, None, None, 128)  512      ['channel_0middle_0[0][0]']
 activation_4 (Activation)                   (None, None, None, 128)  0        ['batch_normalization_4[0][0]']
 channel_0middle_2 (Conv2D)                  (None, None, None, 64)   73792    ['activation_4[0][0]']
 batch_normalization_5 (BatchNormalization)  (None, None, None, 64)   256      ['channel_0middle_2[0][0]']
 activation_5 (Activation)                   (None, None, None, 64)   0        ['batch_normalization_5[0][0]']
 up_sampling2d (UpSampling2D)                (None, None, None, 64)   0        ['activation_5[0][0]']
 concatenate (Concatenate)                   (None, None, None, 128)  0        ['up_sampling2d[0][0]', 'activation_3[0][0]']
 channel_0up_level_1_no_0 (Conv2D)           (None, None, None, 64)   73792    ['concatenate[0][0]']
 batch_normalization_6 (BatchNormalization)  (None, None, None, 64)   256      ['channel_0up_level_1_no_0[0][0]']
 activation_6 (Activation)                   (None, None, None, 64)   0        ['batch_normalization_6[0][0]']
 channel_0up_level_1_no_2 (Conv2D)           (None, None, None, 32)   18464    ['activation_6[0][0]']
 batch_normalization_7 (BatchNormalization)  (None, None, None, 32)   128      ['channel_0up_level_1_no_2[0][0]']
 activation_7 (Activation)                   (None, None, None, 32)   0        ['batch_normalization_7[0][0]']
 up_sampling2d_1 (UpSampling2D)              (None, None, None, 32)   0        ['activation_7[0][0]']
 channel_0up_level_0_no_0 (Conv2D)           (None, None, None, 32)   9248     ['up_sampling2d_1[0][0]']
 batch_normalization_8 (BatchNormalization)  (None, None, None, 32)   128      ['channel_0up_level_0_no_0[0][0]']
 activation_8 (Activation)                   (None, None, None, 32)   0        ['batch_normalization_8[0][0]']
 channel_0up_level_0_no_2 (Conv2D)           (None, None, None, 32)   9248     ['activation_8[0][0]']
 batch_normalization_9 (BatchNormalization)  (None, None, None, 32)   128      ['channel_0up_level_0_no_2[0][0]']
 activation_9 (Activation)                   (None, None, None, 32)   0        ['batch_normalization_9[0][0]']
 conv2d (Conv2D)                             (None, None, None, 1)    33       ['activation_9[0][0]']
 activation_10 (Activation)                  (None, None, None, 1)    0        ['conv2d[0][0]']
====================================================================================================
Total params: 326465 (1.25 MB)
Trainable params: 324513 (1.24 MB)
Non-trainable params: 1952 (7.62 KB)
____________________________________________________________________________________________________
```

PyTorch model with the fixes from this PR:

```
UNet(
  (encoder): UnetEncoder(
    (pooling): MaxBlurPool()
    (encoder_blocks): ModuleList(
      (0): Conv_Block(
        (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (batch_norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (batch_norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (1): MaxBlurPool()
      (2): Conv_Block(
        (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (batch_norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (batch_norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (3): MaxBlurPool()
    )
  )
  (decoder): UnetDecoder(
    (bottleneck): Conv_Block(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batch_norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (batch_norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU()
    )
    (decoder_blocks): ModuleList(
      (0): Upsample(scale_factor=2.0, mode='bilinear')
      (1): Conv_Block(
        (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (batch_norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (batch_norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (2): Upsample(scale_factor=2.0, mode='bilinear')
      (3): Conv_Block(
        (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (batch_norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (batch_norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
    )
  )
  (final_conv): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  (final_activation): Identity()
)
```

```
  | Name    | Type             | Params | Mode
-----------------------------------------------------
0 | model   | UNet             | 324 K  | train
1 | metrics | MetricCollection | 0      | train
-----------------------------------------------------
324 K     Trainable params
0         Non-trainable params
324 K     Total params
1.298     Total estimated model params size (MB)
41        Modules in train mode
0         Modules in eval mode
```

### Additional differences between the TensorFlow and CAREamics implementations

- TensorFlow and PyTorch use different default parameters for batch norm.
- CAREamics uses bilinear or trilinear upsampling in the decoder, whereas TensorFlow uses nearest neighbour.
- In TensorFlow, N2V doesn't use an intermediate multiplier in the decoder but N2V2 does; I will open an issue about this.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
1 parent 3647b14 commit 404b365

File tree

3 files changed: +21 −12 lines


src/careamics/dataset_ng/patching_strategies/random_patching.py (1 addition, 1 deletion)

```diff
@@ -302,7 +302,7 @@ def _generate_random_coords(
         rng.integers(
             np.zeros(len(patch_size), dtype=int),
             np.array(spatial_shape) - np.array(patch_size),
-            endpoint=False,
+            endpoint=True,
             dtype=int,
         ).tolist()
     )
```
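The `endpoint=True` change matters when an axis of `spatial_shape` equals the patch size: the upper bound `spatial_shape - patch_size` is then 0, and with `endpoint=False` the half-open range `[0, 0)` is empty, so `rng.integers` raises. With `endpoint=True` the bound is inclusive and 0 becomes the single valid start coordinate. A minimal check:

```python
import numpy as np

rng = np.random.default_rng(42)
spatial_shape = np.array([8, 8])
patch_size = np.array([8, 8])

# endpoint=True makes the upper bound inclusive, so a patch may start
# at the last valid row/column (here the only valid start is 0)
coords = rng.integers(
    np.zeros(len(patch_size), dtype=int),
    spatial_shape - patch_size,
    endpoint=True,
    dtype=int,
)
print(coords.tolist())  # [0, 0]
```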

src/careamics/models/unet.py (16 additions, 10 deletions)

```diff
@@ -205,16 +205,23 @@ def __init__(
         decoder_blocks: list[nn.Module] = []
         for n in range(depth):
             decoder_blocks.append(upsampling)
-            in_channels = (num_channels_init * 2 ** (depth - n)) * groups
-            out_channels = in_channels // 2
+
+            in_channels = (num_channels_init * 2 ** (depth - n - 1)) * groups
+            # final decoder block has the same number in and out features
+            out_channels = in_channels // 2 if n != depth - 1 else in_channels
+            if not (n2v2 and (n == depth - 1)):
+                in_channels = in_channels * 2  # accounting for skip connection concat
+
             decoder_blocks.append(
                 Conv_Block(
                     conv_dim,
-                    in_channels=(
-                        in_channels + in_channels // 2 if n > 0 else in_channels
-                    ),
+                    in_channels=in_channels,
                     out_channels=out_channels,
-                    intermediate_channel_multiplier=2,
+                    # TODO: Tensorflow n2v implementation has intermediate channel
+                    # multiplication for skip_skipone=True but not skip_skipone=False
+                    # this needs to be benchmarked.
+                    # final decoder block doesn't multiply the intermediate features
+                    intermediate_channel_multiplier=2 if n != depth - 1 else 1,
                     dropout_perc=dropout,
                     activation="ReLU",
                     use_batch_norm=use_batch_norm,
@@ -241,6 +248,7 @@ def forward(self, *features: torch.Tensor) -> torch.Tensor:
         """
         x: torch.Tensor = features[0]
         skip_connections: tuple[torch.Tensor, ...] = features[-1:0:-1]
+        depth = len(skip_connections)

         x = self.bottleneck(x)

@@ -249,10 +257,8 @@ def forward(self, *features: torch.Tensor) -> torch.Tensor:
             if isinstance(module, nn.Upsample):
                 # divide index by 2 because of upsampling layers
                 skip_connection: torch.Tensor = skip_connections[i // 2]
-                if self.n2v2:
-                    if x.shape != skip_connections[-1].shape:
-                        x = self._interleave(x, skip_connection, self.groups)
-                else:
+                # top level skip connection not added for n2v2
+                if (not self.n2v2) or (self.n2v2 and (i // 2 < depth - 1)):
                     x = self._interleave(x, skip_connection, self.groups)
         return x
```
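The new gating condition can be exercised in isolation. This sketch (a hypothetical standalone function, not part of the model) enumerates which `nn.Upsample` positions in `decoder_blocks` trigger a skip-connection concat for `depth=2`:

```python
def uses_skip(i: int, depth: int, n2v2: bool) -> bool:
    # i is the index of the nn.Upsample module in decoder_blocks (0, 2, 4, ...);
    # i // 2 recovers the decoder level. For N2V2, the top level
    # (the last upsampling step) skips the concatenation.
    return (not n2v2) or (i // 2 < depth - 1)

depth = 2
print([uses_skip(i, depth, n2v2=False) for i in (0, 2)])  # [True, True]
print([uses_skip(i, depth, n2v2=True) for i in (0, 2)])   # [True, False]
```

With the old shape-comparison test, the top-level concat still fired for N2V2 whenever the shapes happened to match, which is exactly the bug this PR removes.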

tests/dataset_ng/patching_strategies/test_random_patching_ng.py (4 additions, 1 deletion)

```diff
@@ -65,8 +65,10 @@ def test_n_patches_raises():
     [
         [(1, 1, 19, 37), (8, 8), 11],
         [(1, 1, 19, 37), (8, 5), 18],
+        [(1, 1, 8, 8), (8, 8), 1],
         [(1, 1, 19, 37, 23), (8, 8, 8), 32],
         [(1, 1, 19, 37, 23), (8, 5, 7), 58],
+        [(1, 1, 8, 8, 8), (8, 8, 8), 1],
     ],
 )
 def test_random_coords(data_shape, patch_size, iterations):
@@ -76,7 +78,8 @@ def test_random_coords(data_shape, patch_size, iterations):
     coords = np.array(_generate_random_coords(spatial_shape, patch_size, rng))
     # validate patch is within spatial bounds
     assert (0 <= coords).all()
-    assert (coords + patch_size < np.array(spatial_shape)).all()
+    # less than or equal is correct bec this will be the stop of a slice expression
+    assert (coords + patch_size <= np.array(spatial_shape)).all()
```
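The relaxed `<=` assertion reflects slice semantics: a patch starting at `coord` covers `arr[coord : coord + patch_size]`, so `coord + patch_size == spatial_shape` still lies fully inside the array because the stop index of a slice is exclusive. A small illustration:

```python
import numpy as np

spatial_shape = (8, 8)
patch_size = (8, 8)
arr = np.arange(64).reshape(spatial_shape)

# the only valid start when the patch covers the whole axis is 0,
# and coord + patch_size == spatial_shape is a valid slice stop
coord = (0, 0)
patch = arr[coord[0]:coord[0] + patch_size[0], coord[1]:coord[1] + patch_size[1]]
print(patch.shape)  # (8, 8)
```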
