Skip to content

Commit b0a4612

Browse files
ShapleyCAM
Weighting the activation maps using Gradient and Hessian-Vector Product.
1 parent b1cab2d commit b0a4612

File tree

6 files changed

+234
-3
lines changed

6 files changed

+234
-3
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ The aim is also to serve as a benchmark of algorithms and metrics for research o
4747
| Deep Feature Factorizations | Non Negative Matrix Factorization on the 2D activations |
4848
| KPCA-CAM | Like EigenCAM but with Kernel PCA instead of PCA |
4949
| FEM | A gradient free method that binarizes activations by an activation > mean + k * std rule. |
50+
| ShapleyCAM | Weights the activation maps using the gradient and the Hessian-vector product. |
5051
## Visual Examples
5152

5253
| What makes the network think the image label is 'pug, pug-dog' | What makes the network think the image label is 'tabby, tabby cat' | Combining Grad-CAM with Guided Backpropagation for the 'pug, pug-dog' class |

cam.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pytorch_grad_cam import (
88
GradCAM, FEM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
99
AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
10-
LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM
10+
LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM
1111
)
1212
from pytorch_grad_cam import GuidedBackpropReLUModel
1313
from pytorch_grad_cam.utils.image import (
@@ -37,7 +37,7 @@ def get_args():
3737
'gradcam', 'fem', 'hirescam', 'gradcam++',
3838
'scorecam', 'xgradcam', 'ablationcam',
3939
'eigencam', 'eigengradcam', 'layercam',
40-
'fullgrad', 'gradcamelementwise', 'kpcacam'
40+
'fullgrad', 'gradcamelementwise', 'kpcacam', 'shapleycam'
4141
],
4242
help='CAM method')
4343

@@ -75,7 +75,8 @@ def get_args():
7575
"fullgrad": FullGrad,
7676
"fem": FEM,
7777
"gradcamelementwise": GradCAMElementWise,
78-
'kpcacam': KPCA_CAM
78+
'kpcacam': KPCA_CAM,
79+
'shapleycam': ShapleyCAM
7980
}
8081

8182
if args.device=='hpu':

pytorch_grad_cam/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pytorch_grad_cam.grad_cam import GradCAM
2+
from pytorch_grad_cam.shapley_cam import ShapleyCAM
23
from pytorch_grad_cam.fem import FEM
34
from pytorch_grad_cam.hirescam import HiResCAM
45
from pytorch_grad_cam.grad_cam_elementwise import GradCAMElementWise
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
class ActivationsAndGradients_no_detach:
    """Record activations and gradients of the targeted intermediate layers.

    The recorded tensors are stored exactly as produced by the model: they are
    neither detached, moved to another device, nor passed through
    ``reshape_transform`` (which is only kept as an attribute for callers).
    """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.original_gradients = []
        self.original_activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for layer in target_layers:
            # Because of https://github.com/pytorch/pytorch/issues/61519,
            # we don't use a backward hook to record gradients: a second
            # forward hook attaches a tensor hook to the layer output instead.
            self.handles.extend((
                layer.register_forward_hook(self.save_activation),
                layer.register_forward_hook(self.save_gradient),
            ))

    def save_activation(self, module, input, output):
        """Forward hook: remember the raw output of a target layer."""
        self.original_activations.append(output)

    def save_gradient(self, module, input, output):
        """Forward hook: arrange for the gradient of ``output`` to be recorded."""
        # Tensor hooks can only be registered on tensors that require grad.
        if not getattr(output, "requires_grad", False):
            return

        def _store_grad(grad):
            # Gradients arrive in reverse layer order, so prepend each one to
            # keep the list aligned with the recorded activations.
            self.original_gradients = [grad, *self.original_gradients]

        output.register_hook(_store_grad)

    def __call__(self, x):
        """Clear previous recordings and run the wrapped model on ``x``."""
        self.original_gradients = []
        self.original_activations = []
        return self.model(x)

    def release(self):
        """Remove every hook registered by this object."""
        for hook_handle in self.handles:
            hook_handle.remove()

pytorch_grad_cam/shapley_cam.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
from typing import Callable, List, Optional, Tuple
2+
3+
import numpy as np
4+
import torch
5+
from pytorch_grad_cam.base_cam import BaseCAM
6+
from scipy.signal import convolve2d
7+
from scipy.ndimage import gaussian_filter
8+
import cv2
9+
10+
from pytorch_grad_cam.activations_and_gradients_no_detach import ActivationsAndGradients_no_detach
11+
from pytorch_grad_cam.utils.image import scale_cam_image
12+
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
13+
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
14+
15+
"""
16+
Weighting the activation maps using Gradient and Hessian-Vector Product.
17+
This method (https://arxiv.org/abs/2501.06261) reinterprets CAM methods from a Shapley value perspective.
18+
"""
19+
class ShapleyCAM(BaseCAM):
    """CAM that weights activation maps by gradient and Hessian-vector product.

    For every target layer the channel weights are the spatial mean of
    ``grads - 0.5 * hvp``, where ``grads`` is the gradient of the target score
    with respect to the layer activations and ``hvp`` is obtained by
    differentiating ``grads`` with respect to the activations, using the
    activations themselves as the vector.

    Method reference: https://arxiv.org/abs/2501.06261
    """

    def __init__(self, model, target_layers, reshape_transform=None):
        super().__init__(model, target_layers, reshape_transform)

        # NOTE(review): the base constructor presumably installs its own hook
        # manager under this attribute; release it (if present) so its hooks
        # do not stay registered on the model after being replaced below.
        previous = getattr(self, "activations_and_grads", None)
        if previous is not None and hasattr(previous, "release"):
            previous.release()

        # The second-order term needs activations and gradients that are still
        # attached to the autograd graph, hence the non-detaching recorder.
        self.activations_and_grads = ActivationsAndGradients_no_detach(
            self.model, target_layers, reshape_transform)

    def forward(
        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
    ) -> np.ndarray:
        """Run the model, record differentiable gradients and build the CAM.

        Args:
            input_tensor: Model input; moved to ``self.device``.
            targets: One callable per input mapping a model output to a
                scalar score. When ``None``, the arg-max category of each
                output is wrapped in ``ClassifierOutputTarget``.
            eigen_smooth: Forwarded to ``compute_cam_per_layer``.

        Returns:
            The result of ``self.aggregate_multi_layers`` applied to the
            per-layer CAMs.
        """
        input_tensor = input_tensor.to(self.device)

        input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)

        self.outputs = outputs = self.activations_and_grads(input_tensor)

        if targets is None:
            target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
            targets = [ClassifierOutputTarget(category) for category in target_categories]

        if self.uses_gradients:
            self.model.zero_grad()
            loss = sum(target(output) for target, output in zip(targets, outputs))
            # The returned input gradient is not used; the call is made so the
            # tensor hooks on the target layers record their gradients.
            # create_graph=True keeps those gradients differentiable, which
            # the Hessian-vector product in get_cam_weights relies on.
            torch.autograd.grad(loss, input_tensor, retain_graph=True, create_graph=True)

        # In most of the saliency attribution papers, the saliency is
        # computed with a single target layer.
        # Commonly it is the last convolutional layer.
        # Here we support passing a list with multiple target layers.
        # It will compute the saliency image for every image,
        # and then aggregate them (with a default mean aggregation).
        # This gives you more flexibility in case you just want to
        # use all conv layers for example, all Batchnorm layers,
        # or something else.
        cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
        return self.aggregate_multi_layers(cam_per_layer)

    def get_cam_weights(self,
                        input_tensor,
                        target_layer,
                        target_category,
                        activations,
                        grads):
        """Compute per-channel weights for one target layer.

        Args:
            input_tensor: Unused here; kept for interface compatibility.
            target_layer: Unused here; kept for interface compatibility.
            target_category: Unused here; kept for interface compatibility.
            activations: Graph-attached activation tensor of the layer.
            grads: Graph-attached gradient tensor of the layer.

        Returns:
            A tuple ``(weights, activations)`` of numpy arrays, where
            ``weights`` is ``grads - 0.5 * hvp`` averaged over all axes after
            the first two, and ``activations`` is the (optionally reshaped)
            activation array.

        Raises:
            ValueError: If the activations are neither 4- nor 5-dimensional.
        """
        reshape_transform = self.activations_and_grads.reshape_transform

        hvp = torch.autograd.grad(
            outputs=grads,
            inputs=activations,
            grad_outputs=activations,
            # With several target layers the double-backward graph is shared
            # between them, so its buffers have to survive until the last
            # layer has been processed. A single layer can free them at once.
            retain_graph=len(self.target_layers) > 1,
            allow_unused=True,
        )[0]
        if hvp is None:
            # grads does not depend on activations: no second-order term.
            hvp = torch.tensor(0).to(self.device)
        elif reshape_transform is not None:
            hvp = reshape_transform(hvp)

        if reshape_transform is not None:
            activations = reshape_transform(activations)
            grads = reshape_transform(grads)

        weight = (grads - 0.5 * hvp).cpu().detach().numpy()
        activations = activations.cpu().detach().numpy()

        if activations.ndim == 4:
            # 2D image
            spatial_axes = (2, 3)
        elif activations.ndim == 5:
            # 3D image
            spatial_axes = (2, 3, 4)
        else:
            raise ValueError("Invalid grads shape. "
                             "Shape of grads should be 4 (2D image) or 5 (3D image).")
        return np.mean(weight, axis=spatial_axes), activations

    def get_cam_image(
        self,
        input_tensor: torch.Tensor,
        target_layer: torch.nn.Module,
        targets: List[torch.nn.Module],
        activations: torch.Tensor,
        grads: torch.Tensor,
        eigen_smooth: bool = False,
    ) -> np.ndarray:
        """Combine the weighted activation maps of one layer into a CAM.

        Returns the sum over the channel axis of the weighted activations, or
        their ``get_2d_projection`` when ``eigen_smooth`` is set.

        Raises:
            ValueError: If the activations are neither 4- nor 5-dimensional.
        """
        weights, activations = self.get_cam_weights(
            input_tensor, target_layer, targets, activations, grads)

        if activations.ndim == 4:
            # 2D conv
            weighted_activations = weights[:, :, None, None] * activations
        elif activations.ndim == 5:
            # 3D conv
            weighted_activations = weights[:, :, None, None, None] * activations
        else:
            raise ValueError(f"Invalid activation shape. Get {len(activations.shape)}.")

        if eigen_smooth:
            return get_2d_projection(weighted_activations)
        return weighted_activations.sum(axis=1)

    def compute_cam_per_layer(
        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
    ) -> List[np.ndarray]:
        """Compute one rescaled, non-negative CAM per target layer.

        Returns:
            A list with one array per target layer; each is the output of
            ``scale_cam_image`` with an extra axis inserted at position 1.
        """
        # Work on snapshots of the recorded lists so that anything recorded
        # while the Hessian-vector products are computed cannot shift the
        # per-layer indexing below.
        activations_list = list(self.activations_and_grads.original_activations)
        grads_list = list(self.activations_and_grads.original_gradients)
        target_size = self.get_target_width_height(input_tensor)

        cam_per_target_layer = []
        # Loop over the saliency image from every layer
        for i, target_layer in enumerate(self.target_layers):
            layer_activations = activations_list[i] if i < len(activations_list) else None
            layer_grads = grads_list[i] if i < len(grads_list) else None

            cam = self.get_cam_image(
                input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)
            cam = np.maximum(cam, 0)
            scaled = scale_cam_image(cam, target_size)
            cam_per_target_layer.append(scaled[:, None, :])

        return cam_per_target_layer

pytorch_grad_cam/utils/model_targets.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,22 @@ def __call__(self, model_output):
2323
return torch.softmax(model_output, dim=-1)[:, self.category]
2424

2525

26+
class ClassifierOutputReST:
    """Target combining the pre-softmax and post-softmax outputs.

    Proposed in https://arxiv.org/abs/2501.06261. For the configured category
    the score is the raw output plus the log of its softmax probability,
    which equals the raw output minus the cross-entropy against that category.

    Args:
        category: Index of the class to score along the last axis.
    """

    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        """Return the score for a single output or for each row of a batch.

        A 1-D ``model_output`` yields a 0-dim tensor. A batched one yields one
        score per sample, computed independently for every sample rather than
        against the batch-averaged cross-entropy, so a sample's score does not
        depend on the other samples in the batch. The sum over a batch is the
        same either way.
        """
        log_probs = torch.log_softmax(model_output, dim=-1)
        return model_output[..., self.category] + log_probs[..., self.category]
40+
41+
2642
class BinaryClassifierOutputTarget:
2743
def __init__(self, category):
2844
self.category = category

0 commit comments

Comments
 (0)