@@ -1,21 +1,25 @@
+from typing import Callable, List, Optional, Tuple
+
 import numpy as np
 import torch
 import ttach as tta
-from typing import Callable, List, Tuple, Optional
+
 from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
-from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
 from pytorch_grad_cam.utils.image import scale_cam_image
 from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
 
 
 class BaseCAM:
-    def __init__(self,
-                 model: torch.nn.Module,
-                 target_layers: List[torch.nn.Module],
-                 reshape_transform: Callable = None,
-                 compute_input_gradient: bool = False,
-                 uses_gradients: bool = True,
-                 tta_transforms: Optional[tta.Compose] = None) -> None:
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        target_layers: List[torch.nn.Module],
+        reshape_transform: Callable = None,
+        compute_input_gradient: bool = False,
+        uses_gradients: bool = True,
+        tta_transforms: Optional[tta.Compose] = None,
+    ) -> None:
         self.model = model.eval()
         self.target_layers = target_layers
 
@@ -34,63 +38,64 @@ def __init__(self,
         else:
             self.tta_transforms = tta_transforms
 
-        self.activations_and_grads = ActivationsAndGradients(
-            self.model, target_layers, reshape_transform)
+        self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform)
 
     """ Get a vector of weights for every channel in the target layer.
         Methods that return weights channels,
         will typically need to only implement this function. """
 
-    def get_cam_weights(self,
-                        input_tensor: torch.Tensor,
-                        target_layers: List[torch.nn.Module],
-                        targets: List[torch.nn.Module],
-                        activations: torch.Tensor,
-                        grads: torch.Tensor) -> np.ndarray:
+    def get_cam_weights(
+        self,
+        input_tensor: torch.Tensor,
+        target_layers: List[torch.nn.Module],
+        targets: List[torch.nn.Module],
+        activations: torch.Tensor,
+        grads: torch.Tensor,
+    ) -> np.ndarray:
         raise Exception("Not Implemented")
 
-    def get_cam_image(self,
-                      input_tensor: torch.Tensor,
-                      target_layer: torch.nn.Module,
-                      targets: List[torch.nn.Module],
-                      activations: torch.Tensor,
-                      grads: torch.Tensor,
-                      eigen_smooth: bool = False) -> np.ndarray:
-
-        weights = self.get_cam_weights(input_tensor,
-                                       target_layer,
-                                       targets,
-                                       activations,
-                                       grads)
-        weighted_activations = weights[:, :, None, None] * activations
+    def get_cam_image(
+        self,
+        input_tensor: torch.Tensor,
+        target_layer: torch.nn.Module,
+        targets: List[torch.nn.Module],
+        activations: torch.Tensor,
+        grads: torch.Tensor,
+        eigen_smooth: bool = False,
+    ) -> np.ndarray:
+        weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
+        # 2D conv
+        if len(activations.shape) == 4:
+            weighted_activations = weights[:, :, None, None] * activations
+        # 3D conv
+        elif len(activations.shape) == 5:
+            weighted_activations = weights[:, :, None, None, None] * activations
+        else:
+            raise ValueError(f"Invalid activation shape. Got {len(activations.shape)}.")
+
         if eigen_smooth:
             cam = get_2d_projection(weighted_activations)
         else:
             cam = weighted_activations.sum(axis=1)
         return cam
 
-    def forward(self,
-                input_tensor: torch.Tensor,
-                targets: List[torch.nn.Module],
-                eigen_smooth: bool = False) -> np.ndarray:
-
+    def forward(
+        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
+    ) -> np.ndarray:
         input_tensor = input_tensor.to(self.device)
 
         if self.compute_input_gradient:
-            input_tensor = torch.autograd.Variable(input_tensor,
-                                                   requires_grad=True)
+            input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)
 
         self.outputs = outputs = self.activations_and_grads(input_tensor)
 
         if targets is None:
             target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
-            targets = [ClassifierOutputTarget(
-                category) for category in target_categories]
+            targets = [ClassifierOutputTarget(category) for category in target_categories]
 
         if self.uses_gradients:
             self.model.zero_grad()
-            loss = sum([target(output)
-                        for target, output in zip(targets, outputs)])
+            loss = sum([target(output) for target, output in zip(targets, outputs)])
             loss.backward(retain_graph=True)
 
         # In most of the saliency attribution papers, the saliency is
@@ -102,25 +107,24 @@ def forward(self,
         # This gives you more flexibility in case you just want to
         # use all conv layers for example, all Batchnorm layers,
         # or something else.
-        cam_per_layer = self.compute_cam_per_layer(input_tensor,
-                                                   targets,
-                                                   eigen_smooth)
+        cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
         return self.aggregate_multi_layers(cam_per_layer)
 
-    def get_target_width_height(self,
-                                input_tensor: torch.Tensor) -> Tuple[int, int]:
-        width, height = input_tensor.size(-1), input_tensor.size(-2)
-        return width, height
+    def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, ...]:
+        if len(input_tensor.shape) == 4:
+            width, height = input_tensor.size(-1), input_tensor.size(-2)
+            return width, height
+        elif len(input_tensor.shape) == 5:
+            depth, width, height = input_tensor.size(-1), input_tensor.size(-2), input_tensor.size(-3)
+            return depth, width, height
+        else:
+            raise ValueError("Invalid input_tensor shape. Only 2D or 3D images are supported.")
 
     def compute_cam_per_layer(
-            self,
-            input_tensor: torch.Tensor,
-            targets: List[torch.nn.Module],
-            eigen_smooth: bool) -> np.ndarray:
-        activations_list = [a.cpu().data.numpy()
-                            for a in self.activations_and_grads.activations]
-        grads_list = [g.cpu().data.numpy()
-                      for g in self.activations_and_grads.gradients]
+        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
+    ) -> np.ndarray:
+        activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
+        grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
         target_size = self.get_target_width_height(input_tensor)
 
         cam_per_target_layer = []
@@ -134,36 +138,26 @@ def compute_cam_per_layer(
             if i < len(grads_list):
                 layer_grads = grads_list[i]
 
-            cam = self.get_cam_image(input_tensor,
-                                     target_layer,
-                                     targets,
-                                     layer_activations,
-                                     layer_grads,
-                                     eigen_smooth)
+            cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)
             cam = np.maximum(cam, 0)
             scaled = scale_cam_image(cam, target_size)
             cam_per_target_layer.append(scaled[:, None, :])
 
         return cam_per_target_layer
 
-    def aggregate_multi_layers(
-            self,
-            cam_per_target_layer: np.ndarray) -> np.ndarray:
+    def aggregate_multi_layers(self, cam_per_target_layer: np.ndarray) -> np.ndarray:
         cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
         cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
         result = np.mean(cam_per_target_layer, axis=1)
         return scale_cam_image(result)
 
-    def forward_augmentation_smoothing(self,
-                                       input_tensor: torch.Tensor,
-                                       targets: List[torch.nn.Module],
-                                       eigen_smooth: bool = False) -> np.ndarray:
+    def forward_augmentation_smoothing(
+        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
+    ) -> np.ndarray:
         cams = []
         for transform in self.tta_transforms:
             augmented_tensor = transform.augment_image(input_tensor)
-            cam = self.forward(augmented_tensor,
-                               targets,
-                               eigen_smooth)
+            cam = self.forward(augmented_tensor, targets, eigen_smooth)
 
             # The ttach library expects a tensor of size BxCxHxW
             cam = cam[:, None, :, :]
@@ -178,19 +172,18 @@ def forward_augmentation_smoothing(self,
         cam = np.mean(np.float32(cams), axis=0)
         return cam
 
-    def __call__(self,
-                 input_tensor: torch.Tensor,
-                 targets: List[torch.nn.Module] = None,
-                 aug_smooth: bool = False,
-                 eigen_smooth: bool = False) -> np.ndarray:
-
+    def __call__(
+        self,
+        input_tensor: torch.Tensor,
+        targets: List[torch.nn.Module] = None,
+        aug_smooth: bool = False,
+        eigen_smooth: bool = False,
+    ) -> np.ndarray:
         # Smooth the CAM result with test time augmentation
         if aug_smooth is True:
-            return self.forward_augmentation_smoothing(
-                input_tensor, targets, eigen_smooth)
+            return self.forward_augmentation_smoothing(input_tensor, targets, eigen_smooth)
 
-        return self.forward(input_tensor,
-                            targets, eigen_smooth)
+        return self.forward(input_tensor, targets, eigen_smooth)
 
     def __del__(self):
         self.activations_and_grads.release()
@@ -202,6 +195,5 @@ def __exit__(self, exc_type, exc_value, exc_tb):
         self.activations_and_grads.release()
         if isinstance(exc_value, IndexError):
             # Handle IndexError here...
-            print(
-                f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
+            print(f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
             return True
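
`BaseCAM` leaves `get_cam_weights` abstract: a subclass only has to say how to turn gradients into one weight per channel. Below is a minimal sketch of such a subclass, assuming the GradCAM-style weighting (global-average-pooled gradients) and mirroring the 4D/5D branching this commit adds to `get_cam_image`; the class name is illustrative, not part of this commit.

```python
import numpy as np

from pytorch_grad_cam.base_cam import BaseCAM


class GradCAMSketch(BaseCAM):
    # Illustrative subclass: one weight per channel, obtained by
    # global-average-pooling the gradients over the spatial axes.
    def get_cam_weights(self, input_tensor, target_layer, targets, activations, grads):
        if len(grads.shape) == 4:  # 2D conv: (B, C, H, W)
            return np.mean(grads, axis=(2, 3))
        elif len(grads.shape) == 5:  # 3D conv: (B, C, D, H, W)
            return np.mean(grads, axis=(2, 3, 4))
        raise ValueError(f"Invalid grads shape. Got {len(grads.shape)}.")
```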
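A usage sketch for the new volumetric path. The toy `Conv3d` classifier is hypothetical, and this assumes `scale_cam_image` accepts the three-element target size that `get_target_width_height` now returns for 5D inputs.

```python
import torch
import torch.nn as nn

# Hypothetical 3D classifier; any model with a Conv3d target layer works the same way.
model = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool3d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)

cam = GradCAMSketch(model=model, target_layers=[model[0]])
volume = torch.randn(1, 1, 16, 64, 64)  # B x C x D x H x W
grayscale_cam = cam(input_tensor=volume)  # targets=None picks the top-scoring class
```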