Skip to content

Commit 82c88cc

Browse files
authored
Use custom test-time augmentations (#201)
1 parent 25067d1 commit 82c88cc

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

pytorch_grad_cam/base_cam.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import torch
33
import ttach as tta
4-
from typing import Callable, List, Tuple
4+
from typing import Callable, List, Tuple, Optional
55
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
66
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
77
from pytorch_grad_cam.utils.image import scale_cam_image
@@ -15,7 +15,8 @@ def __init__(self,
1515
use_cuda: bool = False,
1616
reshape_transform: Callable = None,
1717
compute_input_gradient: bool = False,
18-
uses_gradients: bool = True) -> None:
18+
uses_gradients: bool = True,
19+
tta_transforms: Optional[tta.Compose] = None) -> None:
1920
self.model = model.eval()
2021
self.target_layers = target_layers
2122
self.cuda = use_cuda
@@ -24,6 +25,16 @@ def __init__(self,
2425
self.reshape_transform = reshape_transform
2526
self.compute_input_gradient = compute_input_gradient
2627
self.uses_gradients = uses_gradients
28+
if tta_transforms is None:
29+
self.tta_transforms = tta.Compose(
30+
[
31+
tta.HorizontalFlip(),
32+
tta.Multiply(factors=[0.9, 1, 1.1]),
33+
]
34+
)
35+
else:
36+
self.tta_transforms = tta_transforms
37+
2738
self.activations_and_grads = ActivationsAndGradients(
2839
self.model, target_layers, reshape_transform)
2940

@@ -148,14 +159,8 @@ def forward_augmentation_smoothing(self,
148159
input_tensor: torch.Tensor,
149160
targets: List[torch.nn.Module],
150161
eigen_smooth: bool = False) -> np.ndarray:
151-
transforms = tta.Compose(
152-
[
153-
tta.HorizontalFlip(),
154-
tta.Multiply(factors=[0.9, 1, 1.1]),
155-
]
156-
)
157162
cams = []
158-
for transform in transforms:
163+
for transform in self.tta_transforms:
159164
augmented_tensor = transform.augment_image(input_tensor)
160165
cam = self.forward(augmented_tensor,
161166
targets,

0 commit comments

Comments
 (0)