8
8
from lightning .pytorch import LightningModule
9
9
from matplotlib .pyplot import get_cmap
10
10
from monai .optimizers import WarmupCosineSchedule
11
- from monai .transforms import DivisiblePad
11
+ from monai .transforms import DivisiblePad , Rotate90
12
12
from skimage .exposure import rescale_intensity
13
13
from torch import Tensor , nn
14
14
from torch .nn import functional as F
@@ -114,6 +114,10 @@ class VSUNet(LightningModule):
114
114
:param bool test_evaluate_cellpose:
115
115
evaluate the performance of the CellPose model instead of the trained model
116
116
in test stage, defaults to False
117
+ :param bool test_time_augmentations:
118
+ apply test time augmentations in test stage, defaults to False
119
+ :param Literal['mean', 'median', 'product'] tta_type:
120
+ type of test time augmentations aggregation, defaults to "mean"
117
121
"""
118
122
119
123
def __init__ (
@@ -131,6 +135,8 @@ def __init__(
131
135
test_cellpose_model_path : str = None ,
132
136
test_cellpose_diameter : float = None ,
133
137
test_evaluate_cellpose : bool = False ,
138
+ test_time_augmentations : bool = False ,
139
+ tta_type : Literal ["mean" , "median" , "product" ] = "mean" ,
134
140
) -> None :
135
141
super ().__init__ ()
136
142
net_class = _UNET_ARCHITECTURE .get (architecture )
@@ -163,7 +169,10 @@ def __init__(
163
169
self .test_cellpose_model_path = test_cellpose_model_path
164
170
self .test_cellpose_diameter = test_cellpose_diameter
165
171
self .test_evaluate_cellpose = test_evaluate_cellpose
172
+ self .test_time_augmentations = test_time_augmentations
173
+ self .tta_type = tta_type
166
174
self .freeze_encoder = freeze_encoder
175
+ self ._original_shape_yx = None
167
176
if ckpt_path is not None :
168
177
self .load_state_dict (
169
178
torch .load (ckpt_path )["state_dict" ]
@@ -316,8 +325,50 @@ def _log_segmentation_metrics(
316
325
)
317
326
318
327
def predict_step (self , batch : Sample , batch_idx : int , dataloader_idx : int = 0 ):
319
- source = self ._predict_pad (batch ["source" ])
320
- return self ._predict_pad .inverse (self .forward (source ))
328
+ source = batch ["source" ]
329
+ if self .test_time_augmentations :
330
+ prediction = self .perform_test_time_augmentations (source )
331
+ else :
332
+ source = self ._predict_pad (source )
333
+ prediction = self .forward (source )
334
+ prediction = self ._predict_pad .inverse (prediction )
335
+
336
+ return prediction
337
+
338
+ def perform_test_time_augmentations (self , source : Tensor ) -> Tensor :
339
+ """Perform test time augmentations on the input source
340
+ and aggregate the predictions using the specified method.
341
+
342
+ :param source: input tensor
343
+ :return: aggregated prediction
344
+ """
345
+
346
+ # Save the yx coords to crop post rotations
347
+ self ._original_shape_yx = source .shape [- 2 :]
348
+ predictions = []
349
+ for i in range (4 ):
350
+ augmented = self ._rotate_volume (source , k = i , spatial_axes = (1 , 2 ))
351
+ augmented = self ._predict_pad (augmented )
352
+ augmented_prediction = self .forward (augmented )
353
+ de_augmented_prediction = self ._predict_pad .inverse (augmented_prediction )
354
+ de_augmented_prediction = self ._rotate_volume (
355
+ de_augmented_prediction , k = 4 - i , spatial_axes = (1 , 2 )
356
+ )
357
+ de_augmented_prediction = self ._crop_to_original (de_augmented_prediction )
358
+
359
+ # Undo rotation and padding
360
+ predictions .append (de_augmented_prediction )
361
+
362
+ if self .tta_type == "mean" :
363
+ prediction = torch .stack (predictions ).mean (dim = 0 )
364
+ elif self .tta_type == "median" :
365
+ prediction = torch .stack (predictions ).median (dim = 0 ).values
366
+ elif self .tta_type == "product" :
367
+ # Perform multiplication of predictions in logarithmic space for numerical stability adding epsion to avoid log(0) case
368
+ log_predictions = torch .stack ([torch .log (p + 1e-9 ) for p in predictions ])
369
+ log_prediction_sum = log_predictions .sum (dim = 0 )
370
+ prediction = torch .exp (log_prediction_sum )
371
+ return prediction
321
372
322
373
def on_train_epoch_end (self ):
323
374
self ._log_samples ("train_samples" , self .training_step_outputs )
@@ -404,6 +455,33 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
404
455
key , grid , self .current_epoch , dataformats = "HWC"
405
456
)
406
457
458
+ def _rotate_volume (self , tensor : Tensor , k : int , spatial_axes : tuple ) -> Tensor :
459
+ # Padding to ensure square shape
460
+ max_dim = max (tensor .shape [- 2 ], tensor .shape [- 1 ])
461
+ pad_transform = DivisiblePad ((0 , 0 , max_dim , max_dim ))
462
+ padded_tensor = pad_transform (tensor )
463
+
464
+ # Rotation
465
+ rotated_tensor = []
466
+ rotate = Rotate90 (k = k , spatial_axes = spatial_axes )
467
+ for b in range (padded_tensor .shape [0 ]): # iterate over batch
468
+ rotated_tensor .append (rotate (padded_tensor [b ]))
469
+
470
+ # Stack the list of tensors back into a single tensor
471
+ rotated_tensor = torch .stack (rotated_tensor )
472
+ del padded_tensor
473
+ # # Cropping to original shape
474
+ return rotated_tensor
475
+
476
+ def _crop_to_original (self , tensor : Tensor ) -> Tensor :
477
+ original_y , original_x = self ._original_shape_yx
478
+ pad_y = (tensor .shape [- 2 ] - original_y ) // 2
479
+ pad_x = (tensor .shape [- 1 ] - original_x ) // 2
480
+ cropped_tensor = tensor [
481
+ ..., pad_y : pad_y + original_y , pad_x : pad_x + original_x
482
+ ]
483
+ return cropped_tensor
484
+
407
485
408
486
class FcmaeUNet (VSUNet ):
409
487
def __init__ (self , fit_mask_ratio : float = 0.0 , ** kwargs ):
0 commit comments