Skip to content

Commit 955da74

Browse files
authored
Test time augmentations (#91)
* first commit test time agumentations * the ttas probably belong here. * adding the rotations * fixing bug that compose() doesn't catch * fixes to deal with non-square inputs. the rotation functions were breaking * adding cropping of output to original * adding another TTA option for using the product of the stack. * revert the changes to hcs since we dont need them and removing uncessary methods in engine.py * formatting * ruff formatting * fixing docstring, abtracting tta method, removing inference to cpu. * fix missing variable in prediction using ttas * adding the rotations * ruff
1 parent 059ca38 commit 955da74

File tree

4 files changed

+109
-5
lines changed

4 files changed

+109
-5
lines changed

viscy/data/ctmc_v1.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111

1212
class CTMCv1ValidationDataset(SlidingWindowDataset):
13-
1413
def __len__(self, subsample_rate: int = 30) -> int:
1514
# sample every 30th frame in the videos
1615
return super().__len__() // self.subsample_rate

viscy/data/hcs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,10 @@ def _setup_test(self, dataset_settings: dict):
453453
**dataset_settings,
454454
)
455455

456-
def _setup_predict(self, dataset_settings: dict):
456+
def _setup_predict(
457+
self,
458+
dataset_settings: dict,
459+
):
457460
"""Set up the predict stage."""
458461
# track metadata for inverting transform
459462
set_track_meta(True)

viscy/light/engine.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from lightning.pytorch import LightningModule
99
from matplotlib.pyplot import get_cmap
1010
from monai.optimizers import WarmupCosineSchedule
11-
from monai.transforms import DivisiblePad
11+
from monai.transforms import DivisiblePad, Rotate90
1212
from skimage.exposure import rescale_intensity
1313
from torch import Tensor, nn
1414
from torch.nn import functional as F
@@ -114,6 +114,10 @@ class VSUNet(LightningModule):
114114
:param bool test_evaluate_cellpose:
115115
evaluate the performance of the CellPose model instead of the trained model
116116
in test stage, defaults to False
117+
:param bool test_time_augmentations:
118+
apply test time augmentations in test stage, defaults to False
119+
:param Literal['mean', 'median', 'product'] tta_type:
120+
type of test time augmentations aggregation, defaults to "mean"
117121
"""
118122

119123
def __init__(
@@ -131,6 +135,8 @@ def __init__(
131135
test_cellpose_model_path: str = None,
132136
test_cellpose_diameter: float = None,
133137
test_evaluate_cellpose: bool = False,
138+
test_time_augmentations: bool = False,
139+
tta_type: Literal["mean", "median", "product"] = "mean",
134140
) -> None:
135141
super().__init__()
136142
net_class = _UNET_ARCHITECTURE.get(architecture)
@@ -163,7 +169,10 @@ def __init__(
163169
self.test_cellpose_model_path = test_cellpose_model_path
164170
self.test_cellpose_diameter = test_cellpose_diameter
165171
self.test_evaluate_cellpose = test_evaluate_cellpose
172+
self.test_time_augmentations = test_time_augmentations
173+
self.tta_type = tta_type
166174
self.freeze_encoder = freeze_encoder
175+
self._original_shape_yx = None
167176
if ckpt_path is not None:
168177
self.load_state_dict(
169178
torch.load(ckpt_path)["state_dict"]
@@ -316,8 +325,50 @@ def _log_segmentation_metrics(
316325
)
317326

318327
def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
319-
source = self._predict_pad(batch["source"])
320-
return self._predict_pad.inverse(self.forward(source))
328+
source = batch["source"]
329+
if self.test_time_augmentations:
330+
prediction = self.perform_test_time_augmentations(source)
331+
else:
332+
source = self._predict_pad(source)
333+
prediction = self.forward(source)
334+
prediction = self._predict_pad.inverse(prediction)
335+
336+
return prediction
337+
338+
def perform_test_time_augmentations(self, source: Tensor) -> Tensor:
339+
"""Perform test time augmentations on the input source
340+
and aggregate the predictions using the specified method.
341+
342+
:param source: input tensor
343+
:return: aggregated prediction
344+
"""
345+
346+
# Save the yx coords to crop post rotations
347+
self._original_shape_yx = source.shape[-2:]
348+
predictions = []
349+
for i in range(4):
350+
augmented = self._rotate_volume(source, k=i, spatial_axes=(1, 2))
351+
augmented = self._predict_pad(augmented)
352+
augmented_prediction = self.forward(augmented)
353+
de_augmented_prediction = self._predict_pad.inverse(augmented_prediction)
354+
de_augmented_prediction = self._rotate_volume(
355+
de_augmented_prediction, k=4 - i, spatial_axes=(1, 2)
356+
)
357+
de_augmented_prediction = self._crop_to_original(de_augmented_prediction)
358+
359+
# Undo rotation and padding
360+
predictions.append(de_augmented_prediction)
361+
362+
if self.tta_type == "mean":
363+
prediction = torch.stack(predictions).mean(dim=0)
364+
elif self.tta_type == "median":
365+
prediction = torch.stack(predictions).median(dim=0).values
366+
elif self.tta_type == "product":
367+
# Perform multiplication of predictions in logarithmic space for numerical stability adding epsion to avoid log(0) case
368+
log_predictions = torch.stack([torch.log(p + 1e-9) for p in predictions])
369+
log_prediction_sum = log_predictions.sum(dim=0)
370+
prediction = torch.exp(log_prediction_sum)
371+
return prediction
321372

322373
def on_train_epoch_end(self):
323374
self._log_samples("train_samples", self.training_step_outputs)
@@ -404,6 +455,33 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
404455
key, grid, self.current_epoch, dataformats="HWC"
405456
)
406457

458+
def _rotate_volume(self, tensor: Tensor, k: int, spatial_axes: tuple) -> Tensor:
459+
# Padding to ensure square shape
460+
max_dim = max(tensor.shape[-2], tensor.shape[-1])
461+
pad_transform = DivisiblePad((0, 0, max_dim, max_dim))
462+
padded_tensor = pad_transform(tensor)
463+
464+
# Rotation
465+
rotated_tensor = []
466+
rotate = Rotate90(k=k, spatial_axes=spatial_axes)
467+
for b in range(padded_tensor.shape[0]): # iterate over batch
468+
rotated_tensor.append(rotate(padded_tensor[b]))
469+
470+
# Stack the list of tensors back into a single tensor
471+
rotated_tensor = torch.stack(rotated_tensor)
472+
del padded_tensor
473+
# # Cropping to original shape
474+
return rotated_tensor
475+
476+
def _crop_to_original(self, tensor: Tensor) -> Tensor:
477+
original_y, original_x = self._original_shape_yx
478+
pad_y = (tensor.shape[-2] - original_y) // 2
479+
pad_x = (tensor.shape[-1] - original_x) // 2
480+
cropped_tensor = tensor[
481+
..., pad_y : pad_y + original_y, pad_x : pad_x + original_x
482+
]
483+
return cropped_tensor
484+
407485

408486
class FcmaeUNet(VSUNet):
409487
def __init__(self, fit_mask_ratio: float = 0.0, **kwargs):

viscy/light/predict_writer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,30 @@ def _resize_image(image: ImageArray, t_index: int, z_slice: slice) -> None:
3333

3434

3535
def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> NDArray:
36+
"""
37+
Blend a new stack of images into an old stack over a specified range of slices.
38+
39+
This function blends the `new_stack` of images into the `old_stack` over the range
40+
specified by `z_slice`. The blending is done using a weighted average where the
41+
weights are determined by the position within the range of slices. If the start
42+
of `z_slice` is 0, the function returns the `new_stack` unchanged.
43+
44+
Parameters:
45+
----------
46+
old_stack : NDArray
47+
The original stack of images to be blended.
48+
new_stack : NDArray
49+
The new stack of images to blend into the original stack.
50+
z_slice : slice
51+
A slice object indicating the range of slices over which to perform the blending.
52+
The start and stop attributes of the slice determine the range.
53+
54+
Returns:
55+
-------
56+
NDArray
57+
The blended stack of images. If `z_slice.start` is 0, returns `new_stack` unchanged.
58+
"""
59+
3660
if z_slice.start == 0:
3761
return new_stack
3862
depth = z_slice.stop - z_slice.start

0 commit comments

Comments
 (0)