from typing import List

import numpy as np
import torch

from pytorch_grad_cam.activations_and_gradients_no_detach import ActivationsAndGradients_no_detach
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.image import scale_cam_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection

"""
Weighting the activation maps using the gradient and the Hessian-vector product.
This method (https://arxiv.org/abs/2501.06261) reinterprets CAM methods from a Shapley value perspective.
"""
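
# A sketch of the weighting rule, as I read the code and the paper above:
# with target score Y, layer activations A and gradient G = dY/dA,
# ShapleyCAM takes a second-order Taylor view of each channel's
# contribution and weights the activation maps with
#
#     weight = spatial_mean(G - 0.5 * (H @ A)),    H = d2Y/dA2,
#
# where H @ A is obtained with the Hessian-vector-product trick in
# get_cam_weights below, so the Hessian itself is never materialized.
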
class ShapleyCAM(BaseCAM):
    def __init__(self, model, target_layers, reshape_transform=None):
        super().__init__(model, target_layers, reshape_transform)

        # Re-register the hooks with a wrapper that does NOT detach the
        # gradients: they must stay attached to the autograd graph so the
        # Hessian-vector product can be computed in get_cam_weights.
        self.activations_and_grads = ActivationsAndGradients_no_detach(self.model, target_layers, reshape_transform)

    def forward(
        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
    ) -> np.ndarray:
        input_tensor = input_tensor.to(self.device)

        # The input has to be a grad-requiring leaf so that
        # torch.autograd.grad can be called against it below.
        input_tensor = input_tensor.detach().requires_grad_(True)

        self.outputs = outputs = self.activations_and_grads(input_tensor)

        if targets is None:
            target_categories = np.argmax(outputs.detach().cpu().numpy(), axis=-1)
            targets = [ClassifierOutputTarget(category) for category in target_categories]

        if self.uses_gradients:
            self.model.zero_grad()
            loss = sum([target(output) for target, output in zip(targets, outputs)])
            # create_graph=True keeps the backward graph alive so that the
            # second-order Hessian-vector product can be taken afterwards.
            torch.autograd.grad(loss, input_tensor, retain_graph=True, create_graph=True)

        # In most of the saliency attribution papers, the saliency is
        # computed with a single target layer, commonly the last
        # convolutional layer. Here we support passing a list with multiple
        # target layers. The saliency image is computed for every target
        # layer and the results are aggregated (with a default mean
        # aggregation). This gives you more flexibility in case you want to
        # use all conv layers, all batchnorm layers, or something else.
        cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
        return self.aggregate_multi_layers(cam_per_layer)

    def get_cam_weights(self,
                        input_tensor,
                        target_layer,
                        target_category,
                        activations,
                        grads):
        # `activations` and `grads` are single tensors here, still attached
        # to the autograd graph (courtesy of ActivationsAndGradients_no_detach).

        # Hessian-vector product via the usual autograd trick: the
        # vector-Jacobian product of grads (= dY/dA) with the activations
        # yields H @ A without ever forming the Hessian H.
        hvp = torch.autograd.grad(
            outputs=grads,
            inputs=activations,
            grad_outputs=activations,
            retain_graph=False,
            allow_unused=True
        )[0]
        # If the target layer is not connected to the output, the
        # second-order term is simply zero.
        if hvp is None:
            hvp = torch.tensor(0.0, device=self.device)
        elif self.activations_and_grads.reshape_transform is not None:
            hvp = self.activations_and_grads.reshape_transform(hvp)

        if self.activations_and_grads.reshape_transform is not None:
            activations = self.activations_and_grads.reshape_transform(activations)
            grads = self.activations_and_grads.reshape_transform(grads)

        # ShapleyCAM weight: gradient minus half the Hessian-vector product.
        weight = (grads - 0.5 * hvp).cpu().detach().numpy()
        activations = activations.cpu().detach().numpy()
        grads = grads.cpu().detach().numpy()

        # 2D image
        if len(activations.shape) == 4:
            weight = np.mean(weight, axis=(2, 3))
            return weight, activations

        # 3D image
        elif len(activations.shape) == 5:
            weight = np.mean(weight, axis=(2, 3, 4))
            return weight, activations

        else:
            raise ValueError("Invalid grads shape. "
                             "Shape of grads should be 4 (2D image) or 5 (3D image).")

    def get_cam_image(
        self,
        input_tensor: torch.Tensor,
        target_layer: torch.nn.Module,
        targets: List[torch.nn.Module],
        activations: torch.Tensor,
        grads: torch.Tensor,
        eigen_smooth: bool = False,
    ) -> np.ndarray:
        weights, activations = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)

        # 2D conv
        if len(activations.shape) == 4:
            weighted_activations = weights[:, :, None, None] * activations

        # 3D conv
        elif len(activations.shape) == 5:
            weighted_activations = weights[:, :, None, None, None] * activations
        else:
            raise ValueError(f"Invalid activation shape. Got {len(activations.shape)}.")

        # weighted_activations = np.maximum(weighted_activations, 0)
        # weighted_activations = np.abs(weighted_activations)
        if eigen_smooth:
            # Smooth the map by projecting the weighted activations onto
            # their first principal component.
            cam = get_2d_projection(weighted_activations)
        else:
            cam = weighted_activations.sum(axis=1)
        return cam

    def compute_cam_per_layer(
        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
    ) -> List[np.ndarray]:
        # Use the non-detached activations and gradients captured during forward.
        activations_list = list(self.activations_and_grads.original_activations)
        grads_list = list(self.activations_and_grads.original_gradients)
        target_size = self.get_target_width_height(input_tensor)

        cam_per_target_layer = []
        # Loop over every target layer and compute its saliency image
        for i in range(len(self.target_layers)):
            target_layer = self.target_layers[i]
            layer_activations = None
            layer_grads = None
            if i < len(activations_list):
                layer_activations = activations_list[i]
            if i < len(grads_list):
                layer_grads = grads_list[i]

            cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)
            cam = np.maximum(cam, 0)
            scaled = scale_cam_image(cam, target_size)
            cam_per_target_layer.append(scaled[:, None, :])

        return cam_per_target_layer
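

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption-laden example, not part of the class):
# it presumes a torchvision ResNet and the standard pytorch_grad_cam calling
# convention; swap in your own model, target layer and preprocessing.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torchvision.models import resnet50

    model = resnet50().eval()  # random weights are fine for a smoke test
    # The last convolutional block is the common choice of target layer.
    cam = ShapleyCAM(model=model, target_layers=[model.layer4[-1]])

    # A random batch stands in for preprocessed images of shape (N, C, H, W).
    input_tensor = torch.randn(1, 3, 224, 224)
    targets = [ClassifierOutputTarget(281)]  # ImageNet index 281 = tabby cat

    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
    print(grayscale_cam.shape)  # expected: (1, 224, 224)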