Commit 404b365
Fix(UNet): Remove top level skip connection for N2V2 and reduce number of parameters in decoder (#503)
## Description
> [!NOTE]
> **TL;DR**: The top-level skip connection was still being concatenated
> for N2V2 because an incorrect shape comparison was used to decide
> whether the skip connection should be added. Additionally, the number of
> input and output features in the decoder blocks has been modified to
> more closely match the original n2v package implementation.
### Background - why do we need this PR?
The erroneous skip connection caused poor performance for N2V2 and does
not match the intended architecture. The decoder in CAREamics also had
many more parameters than in the original n2v package, because each
decoder block had double the input and output features of its n2v
counterpart. The shape comparison used to decide whether to add the skip
connection was likely wrong precisely because the output features were
double what was expected.
### Overview - what changed?
The test to add the skip connection is now based on the layer index
rather than on comparing the shape of the input with the skip-connection
features. The calculation of the number of input and output channels in
the UNet decoder has been modified to match the n2v TensorFlow
implementation.
Additionally, a small fix to the next-gen dataset random patching has
been made so patches can cover the final row/column of pixels. The
relevant test has been updated accordingly.
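The patching fix amounts to making the upper bound of the random start coordinate inclusive. A minimal sketch of the corrected bound (the helper name and signature are hypothetical, not the actual CAREamics function):

```python
import random

def random_patch_origin(axis_len: int, patch_len: int) -> int:
    """Pick a random patch start coordinate along one axis.

    Illustrative only: `random.randint` is inclusive on both ends, so the
    patch can now start at `axis_len - patch_len` and therefore cover the
    final row/column of pixels.
    """
    return random.randint(0, axis_len - patch_len)
```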
### Implementation - how did you implement the changes?
The input features are now the initial input features multiplied by
`(2 ** (depth - n - 1))` rather than `(2 ** (depth - n))`.
The test to add a skip connection is now `if (not self.n2v2) or
(self.n2v2 and (i // 2 < depth - 1))`.
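The two rules above can be sketched in plain Python (a simplified illustration, not the actual `UnetDecoder` code; `num_channels_init` and the loop variables are assumptions):

```python
def decoder_in_features(num_channels_init: int, depth: int) -> list[int]:
    # Per-level conv-block input features, before any skip concatenation
    # doubles them. Previously the exponent was (depth - n), which gave
    # double the intended width at every level.
    return [num_channels_init * 2 ** (depth - n - 1) for n in range(depth)]

def use_skip_connection(i: int, depth: int, n2v2: bool) -> bool:
    # i indexes decoder_blocks, which alternate upsample (even i) and
    # conv block (odd i). For N2V2 the top level (i // 2 == depth - 1)
    # gets no skip connection.
    return (not n2v2) or (n2v2 and (i // 2 < depth - 1))
```

For the depth-2, 32-channel config in the summaries below this gives conv-block inputs of 64 and 32 (doubled by concatenation only where a skip is applied), and drops the skip at the final level only.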
## Changes Made
### Modified features or files
- `UnetDecoder.__init__`
- `UnetDecoder.forward`
- `careamics.dataset_ng.patching_strategies.random_patching._generate_random_patches`
- test `test_random_coords`
## How has this been tested?
All tests pass, and the changes have been verified in notebooks.
## Additional Notes and Examples
TensorFlow model summary for the N2V2 config:
```
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input (InputLayer) [(None, None, None, 1)] 0 []
channel_0down_level_0_no_0 (None, None, None, 32) 320 ['input[0][0]']
(Conv2D)
batch_normalization (Batch (None, None, None, 32) 128 ['channel_0down_level_0_no_0[0
Normalization) ][0]']
activation (Activation) (None, None, None, 32) 0 ['batch_normalization[0][0]']
channel_0down_level_0_no_1 (None, None, None, 32) 9248 ['activation[0][0]']
(Conv2D)
batch_normalization_1 (Bat (None, None, None, 32) 128 ['channel_0down_level_0_no_1[0
chNormalization) ][0]']
activation_1 (Activation) (None, None, None, 32) 0 ['batch_normalization_1[0][0]'
]
channel_0max_0 (MaxBlurPoo (None, None, None, 32) 288 ['activation_1[0][0]']
l2D)
channel_0down_level_1_no_0 (None, None, None, 64) 18496 ['channel_0max_0[0][0]']
(Conv2D)
batch_normalization_2 (Bat (None, None, None, 64) 256 ['channel_0down_level_1_no_0[0
chNormalization) ][0]']
activation_2 (Activation) (None, None, None, 64) 0 ['batch_normalization_2[0][0]'
]
channel_0down_level_1_no_1 (None, None, None, 64) 36928 ['activation_2[0][0]']
(Conv2D)
batch_normalization_3 (Bat (None, None, None, 64) 256 ['channel_0down_level_1_no_1[0
chNormalization) ][0]']
activation_3 (Activation) (None, None, None, 64) 0 ['batch_normalization_3[0][0]'
]
channel_0max_1 (MaxBlurPoo (None, None, None, 64) 576 ['activation_3[0][0]']
l2D)
channel_0middle_0 (Conv2D) (None, None, None, 128) 73856 ['channel_0max_1[0][0]']
batch_normalization_4 (Bat (None, None, None, 128) 512 ['channel_0middle_0[0][0]']
chNormalization)
activation_4 (Activation) (None, None, None, 128) 0 ['batch_normalization_4[0][0]'
]
channel_0middle_2 (Conv2D) (None, None, None, 64) 73792 ['activation_4[0][0]']
batch_normalization_5 (Bat (None, None, None, 64) 256 ['channel_0middle_2[0][0]']
chNormalization)
activation_5 (Activation) (None, None, None, 64) 0 ['batch_normalization_5[0][0]'
]
up_sampling2d (UpSampling2 (None, None, None, 64) 0 ['activation_5[0][0]']
D)
concatenate (Concatenate) (None, None, None, 128) 0 ['up_sampling2d[0][0]',
'activation_3[0][0]']
channel_0up_level_1_no_0 ( (None, None, None, 64) 73792 ['concatenate[0][0]']
Conv2D)
batch_normalization_6 (Bat (None, None, None, 64) 256 ['channel_0up_level_1_no_0[0][
chNormalization) 0]']
activation_6 (Activation) (None, None, None, 64) 0 ['batch_normalization_6[0][0]'
]
channel_0up_level_1_no_2 ( (None, None, None, 32) 18464 ['activation_6[0][0]']
Conv2D)
batch_normalization_7 (Bat (None, None, None, 32) 128 ['channel_0up_level_1_no_2[0][
chNormalization) 0]']
activation_7 (Activation) (None, None, None, 32) 0 ['batch_normalization_7[0][0]'
]
up_sampling2d_1 (UpSamplin (None, None, None, 32) 0 ['activation_7[0][0]']
g2D)
channel_0up_level_0_no_0 ( (None, None, None, 32) 9248 ['up_sampling2d_1[0][0]']
Conv2D)
batch_normalization_8 (Bat (None, None, None, 32) 128 ['channel_0up_level_0_no_0[0][
chNormalization) 0]']
activation_8 (Activation) (None, None, None, 32) 0 ['batch_normalization_8[0][0]'
]
channel_0up_level_0_no_2 ( (None, None, None, 32) 9248 ['activation_8[0][0]']
Conv2D)
batch_normalization_9 (Bat (None, None, None, 32) 128 ['channel_0up_level_0_no_2[0][
chNormalization) 0]']
activation_9 (Activation) (None, None, None, 32) 0 ['batch_normalization_9[0][0]'
]
conv2d (Conv2D) (None, None, None, 1) 33 ['activation_9[0][0]']
activation_10 (Activation) (None, None, None, 1) 0 ['conv2d[0][0]']
==================================================================================================
Total params: 326465 (1.25 MB)
Trainable params: 324513 (1.24 MB)
Non-trainable params: 1952 (7.62 KB)
__________________________________________________________________________________________________
```
PyTorch model with fixes from this PR
```
UNet(
(encoder): UnetEncoder(
(pooling): MaxBlurPool()
(encoder_blocks): ModuleList(
(0): Conv_Block(
(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(batch_norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(batch_norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(1): MaxBlurPool()
(2): Conv_Block(
(conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(batch_norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(batch_norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(3): MaxBlurPool()
)
)
(decoder): UnetDecoder(
(bottleneck): Conv_Block(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(batch_norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(batch_norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(decoder_blocks): ModuleList(
(0): Upsample(scale_factor=2.0, mode='bilinear')
(1): Conv_Block(
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(batch_norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(batch_norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
(2): Upsample(scale_factor=2.0, mode='bilinear')
(3): Conv_Block(
(conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(batch_norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(batch_norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activation): ReLU()
)
)
)
(final_conv): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
(final_activation): Identity()
)
```
```
| Name | Type | Params | Mode
-----------------------------------------------------
0 | model | UNet | 324 K | train
1 | metrics | MetricCollection | 0 | train
-----------------------------------------------------
324 K Trainable params
0 Non-trainable params
324 K Total params
1.298 Total estimated model params size (MB)
41 Modules in train mode
0 Modules in eval mode
```
### Additional differences between the TensorFlow and CAREamics implementations
- TensorFlow and PyTorch default parameters for batch norm differ (Keras
`BatchNormalization` defaults to `momentum=0.99`, `epsilon=1e-3`, while
PyTorch `BatchNorm2d` defaults to `momentum=0.1`, `eps=1e-5`).
- CAREamics uses bilinear or trilinear upsampling in the decoder, whereas
TensorFlow uses nearest-neighbour.
- In TensorFlow, N2V does not use an intermediate multiplier in the
decoder but N2V2 does; I will open an issue about this.
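On the batch-norm point, the two frameworks also define `momentum` in opposite senses, which is easy to miss when comparing defaults. A minimal sketch of the running-statistics update (plain Python, not either library's actual code):

```python
def keras_running_update(running: float, batch_stat: float,
                         momentum: float = 0.99) -> float:
    # Keras BatchNormalization: momentum weights the OLD running value.
    return momentum * running + (1.0 - momentum) * batch_stat

def torch_running_update(running: float, batch_stat: float,
                         momentum: float = 0.1) -> float:
    # PyTorch BatchNorm2d: momentum weights the NEW batch statistic,
    # so momentum=0.1 here behaves like momentum=0.9 in Keras terms.
    return (1.0 - momentum) * running + momentum * batch_stat
```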
---
**Please ensure your PR meets the following requirements:**
- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
---------
Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com>
3 files changed: +21 −12 lines
- `src/careamics/dataset_ng/patching_strategies`
- `src/careamics/models`
- `tests/dataset_ng/patching_strategies`