diff --git a/README.md b/README.md
index 4908a83b..eb28dcb8 100644
--- a/README.md
+++ b/README.md
@@ -47,6 +47,7 @@ The aim is also to serve as a benchmark of algorithms and metrics for research o
| Deep Feature Factorizations | Non Negative Matrix Factorization on the 2D activations |
| KPCA-CAM | Like EigenCAM but with Kernel PCA instead of PCA |
| FEM | A gradient free method that binarizes activations by an activation > mean + k * std rule. |
+| ShapleyCAM | Weight the activations using the gradient and Hessian-vector product.|
## Visual Examples
| What makes the network think the image label is 'pug, pug-dog' | What makes the network think the image label is 'tabby, tabby cat' | Combining Grad-CAM with Guided Backpropagation for the 'pug, pug-dog' class |
@@ -362,4 +363,8 @@ Sachin Karmani, Thanushon Sivakaran, Gaurav Prasad, Mehmet Ali, Wenbo Yang, Shey
https://hal.science/hal-02963298/document
`Features Understanding in 3D CNNs for Actions Recognition in Video
Kazi Ahmed Asif Fuad, Pierre-Etienne Martin, Romain Giot, Romain
-Bourqui, Jenny Benois-Pineau, Akka Zemmar`
\ No newline at end of file
+Bourqui, Jenny Benois-Pineau, Akka Zemmar`
+
+https://arxiv.org/abs/2501.06261
+`CAMs as Shapley Value-based Explainers
+Huaiguang Cai`
diff --git a/cam.py b/cam.py
index b14b7723..459b8ae2 100644
--- a/cam.py
+++ b/cam.py
@@ -7,13 +7,13 @@
from pytorch_grad_cam import (
GradCAM, FEM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
- LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM
+ LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM
)
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import (
show_cam_on_image, deprocess_image, preprocess_image
)
-from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, ClassifierOutputReST
def get_args():
@@ -37,7 +37,7 @@ def get_args():
'gradcam', 'fem', 'hirescam', 'gradcam++',
'scorecam', 'xgradcam', 'ablationcam',
'eigencam', 'eigengradcam', 'layercam',
- 'fullgrad', 'gradcamelementwise', 'kpcacam'
+ 'fullgrad', 'gradcamelementwise', 'kpcacam', 'shapleycam'
],
help='CAM method')
@@ -75,7 +75,8 @@ def get_args():
"fullgrad": FullGrad,
"fem": FEM,
"gradcamelementwise": GradCAMElementWise,
- 'kpcacam': KPCA_CAM
+ 'kpcacam': KPCA_CAM,
+ 'shapleycam': ShapleyCAM
}
if args.device=='hpu':
@@ -109,7 +110,7 @@ def get_args():
# If targets is None, the highest scoring category (for every member in the batch) will be used.
# You can target specific categories by
# targets = [ClassifierOutputTarget(281)]
- # targets = [ClassifierOutputTarget(281)]
+ # targets = [ClassifierOutputReST(281)]
targets = None
# Using the with statement ensures the context is freed, and you can
diff --git a/pytorch_grad_cam/__init__.py b/pytorch_grad_cam/__init__.py
index 7ac376a8..3b0d2f75 100644
--- a/pytorch_grad_cam/__init__.py
+++ b/pytorch_grad_cam/__init__.py
@@ -1,4 +1,5 @@
from pytorch_grad_cam.grad_cam import GradCAM
+from pytorch_grad_cam.shapley_cam import ShapleyCAM
from pytorch_grad_cam.fem import FEM
from pytorch_grad_cam.hirescam import HiResCAM
from pytorch_grad_cam.grad_cam_elementwise import GradCAMElementWise
diff --git a/pytorch_grad_cam/activations_and_gradients.py b/pytorch_grad_cam/activations_and_gradients.py
index 0c2071e5..8765567d 100644
--- a/pytorch_grad_cam/activations_and_gradients.py
+++ b/pytorch_grad_cam/activations_and_gradients.py
@@ -2,11 +2,12 @@ class ActivationsAndGradients:
""" Class for extracting activations and
registering gradients from targetted intermediate layers """
- def __init__(self, model, target_layers, reshape_transform):
+ def __init__(self, model, target_layers, reshape_transform, detach=True):
self.model = model
self.gradients = []
self.activations = []
self.reshape_transform = reshape_transform
+ self.detach = detach
self.handles = []
for target_layer in target_layers:
self.handles.append(
@@ -18,10 +19,12 @@ def __init__(self, model, target_layers, reshape_transform):
def save_activation(self, module, input, output):
activation = output
-
- if self.reshape_transform is not None:
- activation = self.reshape_transform(activation)
- self.activations.append(activation.cpu().detach())
+ if self.detach:
+ if self.reshape_transform is not None:
+ activation = self.reshape_transform(activation)
+ self.activations.append(activation.cpu().detach())
+ else:
+ self.activations.append(activation)
def save_gradient(self, module, input, output):
if not hasattr(output, "requires_grad") or not output.requires_grad:
@@ -30,9 +33,12 @@ def save_gradient(self, module, input, output):
# Gradients are computed in reverse order
def _store_grad(grad):
- if self.reshape_transform is not None:
- grad = self.reshape_transform(grad)
- self.gradients = [grad.cpu().detach()] + self.gradients
+ if self.detach:
+ if self.reshape_transform is not None:
+ grad = self.reshape_transform(grad)
+ self.gradients = [grad.cpu().detach()] + self.gradients
+ else:
+ self.gradients = [grad] + self.gradients
output.register_hook(_store_grad)
diff --git a/pytorch_grad_cam/base_cam.py b/pytorch_grad_cam/base_cam.py
index 44ae5b90..484e8865 100644
--- a/pytorch_grad_cam/base_cam.py
+++ b/pytorch_grad_cam/base_cam.py
@@ -19,6 +19,7 @@ def __init__(
compute_input_gradient: bool = False,
uses_gradients: bool = True,
tta_transforms: Optional[tta.Compose] = None,
+ detach: bool = True,
) -> None:
self.model = model.eval()
self.target_layers = target_layers
@@ -45,7 +46,8 @@ def __init__(
else:
self.tta_transforms = tta_transforms
- self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform)
+ self.detach = detach
+ self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform, self.detach)
""" Get a vector of weights for every channel in the target layer.
Methods that return weights channels,
@@ -71,6 +73,8 @@ def get_cam_image(
eigen_smooth: bool = False,
) -> np.ndarray:
weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
+ if isinstance(activations, torch.Tensor):
+ activations = activations.cpu().detach().numpy()
# 2D conv
if len(activations.shape) == 4:
weighted_activations = weights[:, :, None, None] * activations
@@ -103,7 +107,13 @@ def forward(
if self.uses_gradients:
self.model.zero_grad()
loss = sum([target(output) for target, output in zip(targets, outputs)])
- loss.backward(retain_graph=True)
+ if self.detach:
+ loss.backward(retain_graph=True)
+ else:
+ # keep the computational graph, create_graph = True is needed for hvp
+ torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
+ # When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle"
+ # loss.backward(retain_graph=True, create_graph=True)
if 'hpu' in str(self.device):
self.__htcore.mark_step()
@@ -132,8 +142,12 @@ def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int]
def compute_cam_per_layer(
self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
) -> np.ndarray:
- activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
- grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
+ if self.detach:
+ activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
+ grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
+ else:
+ activations_list = [a for a in self.activations_and_grads.activations]
+ grads_list = [g for g in self.activations_and_grads.gradients]
target_size = self.get_target_width_height(input_tensor)
cam_per_target_layer = []
diff --git a/pytorch_grad_cam/shapley_cam.py b/pytorch_grad_cam/shapley_cam.py
new file mode 100644
index 00000000..e8331528
--- /dev/null
+++ b/pytorch_grad_cam/shapley_cam.py
@@ -0,0 +1,60 @@
+from typing import Callable, List, Optional, Tuple
+from pytorch_grad_cam.base_cam import BaseCAM
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+import torch
+import numpy as np
+
+"""
+Weights the activation maps using the gradient and Hessian-Vector product.
+This method (https://arxiv.org/abs/2501.06261) reinterpret CAM methods (include GradCAM, HiResCAM and the original CAM) from a Shapley value perspective.
+"""
+class ShapleyCAM(BaseCAM):
+ def __init__(self, model, target_layers,
+ reshape_transform=None):
+ super(
+ ShapleyCAM,
+ self).__init__(
+ model = model,
+ target_layers = target_layers,
+ reshape_transform = reshape_transform,
+ compute_input_gradient = True,
+ uses_gradients = True,
+ detach = False)
+
+ def get_cam_weights(self,
+ input_tensor,
+ target_layer,
+ target_category,
+ activations,
+ grads):
+
+ hvp = torch.autograd.grad(
+ outputs=grads,
+ inputs=activations,
+ grad_outputs=activations,
+ retain_graph=False,
+ allow_unused=True
+ )[0]
+ # print(torch.max(hvp[0]).item()) # check if hvp is not all zeros
+ if hvp is None:
+ hvp = torch.tensor(0).to(self.device)
+ else:
+ if self.activations_and_grads.reshape_transform is not None:
+ hvp = self.activations_and_grads.reshape_transform(hvp)
+
+ if self.activations_and_grads.reshape_transform is not None:
+ activations = self.activations_and_grads.reshape_transform(activations)
+ grads = self.activations_and_grads.reshape_transform(grads)
+
+ weight = (grads - 0.5 * hvp).detach().cpu().numpy()
+ # 2D image
+ if len(activations.shape) == 4:
+ weight = np.mean(weight, axis=(2, 3))
+ return weight
+ # 3D image
+ elif len(activations.shape) == 5:
+ weight = np.mean(weight, axis=(2, 3, 4))
+ return weight
+ else:
+ raise ValueError("Invalid grads shape."
+ "Shape of grads should be 4 (2D image) or 5 (3D image).")
diff --git a/pytorch_grad_cam/utils/model_targets.py b/pytorch_grad_cam/utils/model_targets.py
index d0a48189..4861ab9c 100644
--- a/pytorch_grad_cam/utils/model_targets.py
+++ b/pytorch_grad_cam/utils/model_targets.py
@@ -23,6 +23,22 @@ def __call__(self, model_output):
return torch.softmax(model_output, dim=-1)[:, self.category]
+class ClassifierOutputReST:
+ """
+ Using both pre-softmax and post-softmax, proposed in https://arxiv.org/abs/2501.06261
+ """
+ def __init__(self, category):
+ self.category = category
+ def __call__(self, model_output):
+ if len(model_output.shape) == 1:
+ target = torch.tensor([self.category], device=model_output.device)
+ model_output = model_output.unsqueeze(0)
+ return model_output[0][self.category] - torch.nn.functional.cross_entropy(model_output, target)
+ else:
+ target = torch.tensor([self.category] * model_output.shape[0], device=model_output.device)
+ return model_output[:,self.category] - torch.nn.functional.cross_entropy(model_output, target)
+
+
class BinaryClassifierOutputTarget:
def __init__(self, category):
self.category = category