1
1
import numpy as np
2
2
import torch
3
3
import ttach as tta
4
- from typing import Callable , List , Tuple
4
+ from typing import Callable , List , Tuple , Optional
5
5
from pytorch_grad_cam .activations_and_gradients import ActivationsAndGradients
6
6
from pytorch_grad_cam .utils .svd_on_activations import get_2d_projection
7
7
from pytorch_grad_cam .utils .image import scale_cam_image
@@ -15,7 +15,8 @@ def __init__(self,
15
15
use_cuda : bool = False ,
16
16
reshape_transform : Callable = None ,
17
17
compute_input_gradient : bool = False ,
18
- uses_gradients : bool = True ) -> None :
18
+ uses_gradients : bool = True ,
19
+ tta_transforms : Optional [tta .Compose ] = None ) -> None :
19
20
self .model = model .eval ()
20
21
self .target_layers = target_layers
21
22
self .cuda = use_cuda
@@ -24,6 +25,16 @@ def __init__(self,
24
25
self .reshape_transform = reshape_transform
25
26
self .compute_input_gradient = compute_input_gradient
26
27
self .uses_gradients = uses_gradients
28
+ if tta_transforms is None :
29
+ self .tta_transforms = tta .Compose (
30
+ [
31
+ tta .HorizontalFlip (),
32
+ tta .Multiply (factors = [0.9 , 1 , 1.1 ]),
33
+ ]
34
+ )
35
+ else :
36
+ self .tta_transforms = tta_transforms
37
+
27
38
self .activations_and_grads = ActivationsAndGradients (
28
39
self .model , target_layers , reshape_transform )
29
40
@@ -148,14 +159,8 @@ def forward_augmentation_smoothing(self,
148
159
input_tensor : torch .Tensor ,
149
160
targets : List [torch .nn .Module ],
150
161
eigen_smooth : bool = False ) -> np .ndarray :
151
- transforms = tta .Compose (
152
- [
153
- tta .HorizontalFlip (),
154
- tta .Multiply (factors = [0.9 , 1 , 1.1 ]),
155
- ]
156
- )
157
162
cams = []
158
- for transform in transforms :
163
+ for transform in self . tta_transforms :
159
164
augmented_tensor = transform .augment_image (input_tensor )
160
165
cam = self .forward (augmented_tensor ,
161
166
targets ,
0 commit comments