diff --git a/README.md b/README.md index 4908a83b..eb28dcb8 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@ The aim is also to serve as a benchmark of algorithms and metrics for research o | Deep Feature Factorizations | Non Negative Matrix Factorization on the 2D activations | | KPCA-CAM | Like EigenCAM but with Kernel PCA instead of PCA | | FEM | A gradient free method that binarizes activations by an activation > mean + k * std rule. | +| ShapleyCAM | Weight the activations using the gradient and Hessian-vector product.| ## Visual Examples | What makes the network think the image label is 'pug, pug-dog' | What makes the network think the image label is 'tabby, tabby cat' | Combining Grad-CAM with Guided Backpropagation for the 'pug, pug-dog' class | @@ -362,4 +363,8 @@ Sachin Karmani, Thanushon Sivakaran, Gaurav Prasad, Mehmet Ali, Wenbo Yang, Shey https://hal.science/hal-02963298/document
`Features Understanding in 3D CNNs for Actions Recognition in Video Kazi Ahmed Asif Fuad, Pierre-Etienne Martin, Romain Giot, Romain -Bourqui, Jenny Benois-Pineau, Akka Zemmar` \ No newline at end of file +Bourqui, Jenny Benois-Pineau, Akka Zemmar` + +https://arxiv.org/abs/2501.06261
+`CAMs as Shapley Value-based Explainers +Huaiguang Cai` diff --git a/cam.py b/cam.py index b14b7723..459b8ae2 100644 --- a/cam.py +++ b/cam.py @@ -7,13 +7,13 @@ from pytorch_grad_cam import ( GradCAM, FEM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, EigenGradCAM, - LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM + LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM ) from pytorch_grad_cam import GuidedBackpropReLUModel from pytorch_grad_cam.utils.image import ( show_cam_on_image, deprocess_image, preprocess_image ) -from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget +from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, ClassifierOutputReST def get_args(): @@ -37,7 +37,7 @@ def get_args(): 'gradcam', 'fem', 'hirescam', 'gradcam++', 'scorecam', 'xgradcam', 'ablationcam', 'eigencam', 'eigengradcam', 'layercam', - 'fullgrad', 'gradcamelementwise', 'kpcacam' + 'fullgrad', 'gradcamelementwise', 'kpcacam', 'shapleycam' ], help='CAM method') @@ -75,7 +75,8 @@ def get_args(): "fullgrad": FullGrad, "fem": FEM, "gradcamelementwise": GradCAMElementWise, - 'kpcacam': KPCA_CAM + 'kpcacam': KPCA_CAM, + 'shapleycam': ShapleyCAM } if args.device=='hpu': @@ -109,7 +110,7 @@ def get_args(): # If targets is None, the highest scoring category (for every member in the batch) will be used. 
# You can target specific categories by # targets = [ClassifierOutputTarget(281)] - # targets = [ClassifierOutputTarget(281)] + # targets = [ClassifierOutputReST(281)] targets = None # Using the with statement ensures the context is freed, and you can diff --git a/pytorch_grad_cam/__init__.py b/pytorch_grad_cam/__init__.py index 7ac376a8..3b0d2f75 100644 --- a/pytorch_grad_cam/__init__.py +++ b/pytorch_grad_cam/__init__.py @@ -1,4 +1,5 @@ from pytorch_grad_cam.grad_cam import GradCAM +from pytorch_grad_cam.shapley_cam import ShapleyCAM from pytorch_grad_cam.fem import FEM from pytorch_grad_cam.hirescam import HiResCAM from pytorch_grad_cam.grad_cam_elementwise import GradCAMElementWise diff --git a/pytorch_grad_cam/activations_and_gradients.py b/pytorch_grad_cam/activations_and_gradients.py index 0c2071e5..8765567d 100644 --- a/pytorch_grad_cam/activations_and_gradients.py +++ b/pytorch_grad_cam/activations_and_gradients.py @@ -2,11 +2,12 @@ class ActivationsAndGradients: """ Class for extracting activations and registering gradients from targetted intermediate layers """ - def __init__(self, model, target_layers, reshape_transform): + def __init__(self, model, target_layers, reshape_transform, detach=True): self.model = model self.gradients = [] self.activations = [] self.reshape_transform = reshape_transform + self.detach = detach self.handles = [] for target_layer in target_layers: self.handles.append( @@ -18,10 +19,12 @@ def __init__(self, model, target_layers, reshape_transform): def save_activation(self, module, input, output): activation = output - - if self.reshape_transform is not None: - activation = self.reshape_transform(activation) - self.activations.append(activation.cpu().detach()) + if self.detach: + if self.reshape_transform is not None: + activation = self.reshape_transform(activation) + self.activations.append(activation.cpu().detach()) + else: + self.activations.append(activation) def save_gradient(self, module, input, output): if not 
hasattr(output, "requires_grad") or not output.requires_grad: @@ -30,9 +33,12 @@ def save_gradient(self, module, input, output): # Gradients are computed in reverse order def _store_grad(grad): - if self.reshape_transform is not None: - grad = self.reshape_transform(grad) - self.gradients = [grad.cpu().detach()] + self.gradients + if self.detach: + if self.reshape_transform is not None: + grad = self.reshape_transform(grad) + self.gradients = [grad.cpu().detach()] + self.gradients + else: + self.gradients = [grad] + self.gradients output.register_hook(_store_grad) diff --git a/pytorch_grad_cam/base_cam.py b/pytorch_grad_cam/base_cam.py index 44ae5b90..484e8865 100644 --- a/pytorch_grad_cam/base_cam.py +++ b/pytorch_grad_cam/base_cam.py @@ -19,6 +19,7 @@ def __init__( compute_input_gradient: bool = False, uses_gradients: bool = True, tta_transforms: Optional[tta.Compose] = None, + detach: bool = True, ) -> None: self.model = model.eval() self.target_layers = target_layers @@ -45,7 +46,8 @@ def __init__( else: self.tta_transforms = tta_transforms - self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform) + self.detach = detach + self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform, self.detach) """ Get a vector of weights for every channel in the target layer. 
Methods that return weights channels, @@ -71,6 +73,8 @@ def get_cam_image( eigen_smooth: bool = False, ) -> np.ndarray: weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads) + if isinstance(activations, torch.Tensor): + activations = activations.cpu().detach().numpy() # 2D conv if len(activations.shape) == 4: weighted_activations = weights[:, :, None, None] * activations @@ -103,7 +107,13 @@ def forward( if self.uses_gradients: self.model.zero_grad() loss = sum([target(output) for target, output in zip(targets, outputs)]) - loss.backward(retain_graph=True) + if self.detach: + loss.backward(retain_graph=True) + else: + # keep the computational graph, create_graph = True is needed for hvp + torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True) + # When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle" + # loss.backward(retain_graph=True, create_graph=True) if 'hpu' in str(self.device): self.__htcore.mark_step() @@ -132,8 +142,12 @@ def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int] def compute_cam_per_layer( self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool ) -> np.ndarray: - activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations] - grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients] + if self.detach: + activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations] + grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients] + else: + activations_list = [a for a in self.activations_and_grads.activations] + grads_list = [g for g in self.activations_and_grads.gradients] target_size = self.get_target_width_height(input_tensor) cam_per_target_layer = [] diff --git a/pytorch_grad_cam/shapley_cam.py b/pytorch_grad_cam/shapley_cam.py new 
file mode 100644
index 00000000..e8331528
--- /dev/null
+++ b/pytorch_grad_cam/shapley_cam.py
@@ -0,0 +1,67 @@
+from typing import Callable, List, Optional, Tuple
+from pytorch_grad_cam.base_cam import BaseCAM
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+import torch
+import numpy as np
+
+"""
+Weights the activation maps using the gradient and Hessian-Vector product.
+This method (https://arxiv.org/abs/2501.06261) reinterprets CAM methods (including GradCAM, HiResCAM and the original CAM) from a Shapley value perspective.
+"""
+class ShapleyCAM(BaseCAM):
+    def __init__(self, model, target_layers,
+                 reshape_transform=None):
+        # detach=False keeps activations/gradients attached to the autograd
+        # graph so get_cam_weights can differentiate through them again.
+        super(
+            ShapleyCAM,
+            self).__init__(
+            model = model,
+            target_layers = target_layers,
+            reshape_transform = reshape_transform,
+            compute_input_gradient = True,
+            uses_gradients = True,
+            detach = False)
+
+    def get_cam_weights(self,
+                        input_tensor,
+                        target_layer,
+                        target_category,
+                        activations,
+                        grads):
+        """
+        Return per-channel weights: the spatial mean of (grads - 0.5 * hvp),
+        where hvp is the gradient of grads w.r.t. activations contracted
+        with activations (a Hessian-vector product).
+        """
+        hvp = torch.autograd.grad(
+            outputs=grads,
+            inputs=activations,
+            grad_outputs=activations,
+            retain_graph=False,
+            allow_unused=True
+        )[0]
+        # allow_unused=True yields None when grads does not depend on
+        # activations; fall back to a zero Hessian-vector product.
+        if hvp is None:
+            hvp = torch.tensor(0).to(self.device)
+        else:
+            if self.activations_and_grads.reshape_transform is not None:
+                hvp = self.activations_and_grads.reshape_transform(hvp)
+
+        if self.activations_and_grads.reshape_transform is not None:
+            activations = self.activations_and_grads.reshape_transform(activations)
+            grads = self.activations_and_grads.reshape_transform(grads)
+
+        weight = (grads - 0.5 * hvp).detach().cpu().numpy()
+        # 2D image
+        if len(activations.shape) == 4:
+            weight = np.mean(weight, axis=(2, 3))
+            return weight
+        # 3D image
+        elif len(activations.shape) == 5:
+            weight = np.mean(weight, axis=(2, 3, 4))
+            return weight
+        else:
+            raise ValueError("Invalid grads shape. "
+                             "Shape of grads should be 4 (2D image) or 5 (3D image).")
diff --git a/pytorch_grad_cam/utils/model_targets.py b/pytorch_grad_cam/utils/model_targets.py
index d0a48189..4861ab9c 100644
--- a/pytorch_grad_cam/utils/model_targets.py
+++ b/pytorch_grad_cam/utils/model_targets.py
@@ -23,6 +23,27 @@ def __call__(self, model_output):
         return torch.softmax(model_output, dim=-1)[:, self.category]
 
 
+class ClassifierOutputReST:
+    """
+    Target mixing pre- and post-softmax scores: the logit of `category`
+    minus the cross-entropy of the output against `category`.
+    Proposed in https://arxiv.org/abs/2501.06261
+    """
+    def __init__(self, category):
+        self.category = category
+
+    def __call__(self, model_output):
+        if len(model_output.shape) == 1:
+            target = torch.tensor([self.category], device=model_output.device)
+            model_output = model_output.unsqueeze(0)
+            return model_output[0][self.category] - torch.nn.functional.cross_entropy(model_output, target)
+        else:
+            target = torch.tensor([self.category] * model_output.shape[0], device=model_output.device)
+            # reduction='none' keeps one loss per sample; the default 'mean'
+            # would subtract the batch-averaged loss from every sample.
+            return model_output[:,self.category] - torch.nn.functional.cross_entropy(model_output, target, reduction='none')
+
+
 class BinaryClassifierOutputTarget:
     def __init__(self, category):
         self.category = category