1
1
from typing import Callable , List , Optional , Tuple
2
-
3
- import numpy as np
4
- import torch
5
2
from pytorch_grad_cam .base_cam import BaseCAM
6
- from scipy .signal import convolve2d
7
- from scipy .ndimage import gaussian_filter
8
- import cv2
9
-
10
- from pytorch_grad_cam .activations_and_gradients_no_detach import ActivationsAndGradients_no_detach
11
- from pytorch_grad_cam .utils .image import scale_cam_image
12
- from pytorch_grad_cam .utils .model_targets import ClassifierOutputTarget
13
- from pytorch_grad_cam .utils .svd_on_activations import get_2d_projection
3
+ import torch
4
+ import numpy as np
14
5
15
6
"""
16
7
Weighting the activation maps using Gradient and Hessian-Vector Product.
17
- This method (https://arxiv.org/abs/2501.06261) reinterprets CAM methods from a Shapley value perspective.
8
+ This method (https://arxiv.org/abs/2501.06261) reinterprets CAM methods (including GradCAM, HiResCAM, and the original CAM) from a Shapley value perspective.
18
9
"""
19
10
class ShapleyCAM (BaseCAM ):
20
11
def __init__ (self , model , target_layers ,
21
- reshape_transform = None ):
12
+ reshape_transform = None , detach = False ):
22
13
super (
23
14
ShapleyCAM ,
24
15
self ).__init__ (
25
16
model ,
26
17
target_layers ,
27
- reshape_transform )
28
-
29
- self .activations_and_grads = ActivationsAndGradients_no_detach (self .model , target_layers , reshape_transform )
18
+ reshape_transform ,
19
+ detach = detach )
30
20
31
21
def forward (
32
22
self , input_tensor : torch .Tensor , targets : List [torch .nn .Module ], eigen_smooth : bool = False
@@ -44,6 +34,7 @@ def forward(
44
34
if self .uses_gradients :
45
35
self .model .zero_grad ()
46
36
loss = sum ([target (output ) for target , output in zip (targets , outputs )])
37
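+ # create_graph=True keeps the gradient graph differentiable so that get_cam_weights can take a Hessian-vector product from it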
47
38
torch .autograd .grad (loss , input_tensor , retain_graph = True , create_graph = True )
48
39
49
40
# In most of the saliency attribution papers, the saliency is
@@ -65,96 +56,36 @@ def get_cam_weights(self,
65
56
target_category ,
66
57
activations ,
67
58
grads ):
68
- activations : List [Tensor ] # type: ignore[assignment]
69
- grads : List [Tensor ] # type: ignore[assignment]
70
-
59
+
71
60
hvp = torch .autograd .grad (
72
61
outputs = grads ,
73
62
inputs = activations ,
74
63
grad_outputs = activations ,
75
64
retain_graph = False ,
76
65
allow_unused = True
77
66
)[0 ]
78
- # print(torch.max(hvp[0]).item()) # verify that hvp is not all zeros
67
+ # print(torch.max(hvp[0]).item()) # Use .item() to get the scalar value
79
68
if hvp is None :
80
69
hvp = torch .tensor (0 ).to (self .device )
81
- elif self .activations_and_grads .reshape_transform is not None :
82
- hvp = self .activations_and_grads .reshape_transform (hvp )
70
+ else :
71
+ if self .activations_and_grads .reshape_transform is not None :
72
+ hvp = self .activations_and_grads .reshape_transform (hvp )
83
73
84
74
if self .activations_and_grads .reshape_transform is not None :
85
75
activations = self .activations_and_grads .reshape_transform (activations )
86
76
grads = self .activations_and_grads .reshape_transform (grads )
87
- weight = (grads - 0.5 * hvp ).cpu ().detach ().numpy ()
88
- activations = activations .cpu ().detach ().numpy ()
89
- grads = grads .cpu ().detach ().numpy ()
90
-
91
77
78
+ weight = (grads - 0.5 * hvp ).detach ().cpu ().numpy ()
92
79
# 2D image
93
80
if len (activations .shape ) == 4 :
94
81
weight = np .mean (weight , axis = (2 , 3 ))
95
- return weight , activations
82
+ return weight
96
83
97
84
# 3D image
98
85
elif len (activations .shape ) == 5 :
99
86
weight = np .mean (weight , axis = (2 , 3 , 4 ))
100
- return weight , activations
87
+ return weight
101
88
102
89
else :
103
90
raise ValueError ("Invalid grads shape."
104
91
"Shape of grads should be 4 (2D image) or 5 (3D image)." )
105
-
106
-
107
-
108
- def get_cam_image (
109
- self ,
110
- input_tensor : torch .Tensor ,
111
- target_layer : torch .nn .Module ,
112
- targets : List [torch .nn .Module ],
113
- activations : torch .Tensor ,
114
- grads : torch .Tensor ,
115
- eigen_smooth : bool = False ,
116
- ) -> np .ndarray :
117
- weights , activations = self .get_cam_weights (input_tensor , target_layer , targets , activations , grads )
118
-
119
- # 2D conv
120
- if len (activations .shape ) == 4 :
121
- weighted_activations = weights [:, :, None , None ] * activations
122
-
123
- # 3D conv
124
- elif len (activations .shape ) == 5 :
125
- weighted_activations = weights [:, :, None , None , None ] * activations
126
- else :
127
- raise ValueError (f"Invalid activation shape. Get { len (activations .shape )} ." )
128
-
129
- # weighted_activations = np.maximum(weighted_activations, 0)
130
- # weighted_activations = np.abs(weighted_activations)
131
- if eigen_smooth :
132
- cam = get_2d_projection (weighted_activations )
133
- else :
134
- cam = weighted_activations .sum (axis = 1 )
135
- return cam
136
-
137
- def compute_cam_per_layer (
138
- self , input_tensor : torch .Tensor , targets : List [torch .nn .Module ], eigen_smooth : bool
139
- ) -> np .ndarray :
140
- activations_list = [a for a in self .activations_and_grads .original_activations ]
141
- grads_list = [g for g in self .activations_and_grads .original_gradients ]
142
- target_size = self .get_target_width_height (input_tensor )
143
-
144
- cam_per_target_layer = []
145
- # Loop over the saliency image from every layer
146
- for i in range (len (self .target_layers )):
147
- target_layer = self .target_layers [i ]
148
- layer_activations = None
149
- layer_grads = None
150
- if i < len (activations_list ):
151
- layer_activations = activations_list [i ]
152
- if i < len (grads_list ):
153
- layer_grads = grads_list [i ]
154
-
155
- cam = self .get_cam_image (input_tensor , target_layer , targets , layer_activations , layer_grads , eigen_smooth )
156
- cam = np .maximum (cam , 0 )
157
- scaled = scale_cam_image (cam , target_size )
158
- cam_per_target_layer .append (scaled [:, None , :])
159
-
160
- return cam_per_target_layer
0 commit comments