Skip to content

Commit 4a8d358

Browse files
update a simpler version
1 parent ddf1618 commit 4a8d358

File tree

5 files changed

+42
-150
lines changed

5 files changed

+42
-150
lines changed

pytorch_grad_cam/activations_and_gradients.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ class ActivationsAndGradients:
22
""" Class for extracting activations and
33
registering gradients from targetted intermediate layers """
44

5-
def __init__(self, model, target_layers, reshape_transform):
5+
def __init__(self, model, target_layers, reshape_transform, detach=True):
66
self.model = model
77
self.gradients = []
88
self.activations = []
99
self.reshape_transform = reshape_transform
10+
self.detach = detach
1011
self.handles = []
1112
for target_layer in target_layers:
1213
self.handles.append(
@@ -18,10 +19,12 @@ def __init__(self, model, target_layers, reshape_transform):
1819

1920
def save_activation(self, module, input, output):
2021
activation = output
21-
22-
if self.reshape_transform is not None:
23-
activation = self.reshape_transform(activation)
24-
self.activations.append(activation.cpu().detach())
22+
if self.detach:
23+
if self.reshape_transform is not None:
24+
activation = self.reshape_transform(activation)
25+
self.activations.append(activation.cpu().detach())
26+
else:
27+
self.activations.append(activation)
2528

2629
def save_gradient(self, module, input, output):
2730
if not hasattr(output, "requires_grad") or not output.requires_grad:
@@ -30,9 +33,12 @@ def save_gradient(self, module, input, output):
3033

3134
# Gradients are computed in reverse order
3235
def _store_grad(grad):
33-
if self.reshape_transform is not None:
34-
grad = self.reshape_transform(grad)
35-
self.gradients = [grad.cpu().detach()] + self.gradients
36+
if self.detach:
37+
if self.reshape_transform is not None:
38+
grad = self.reshape_transform(grad)
39+
self.gradients = [grad.cpu().detach()] + self.gradients
40+
else:
41+
self.gradients = [grad] + self.gradients
3642

3743
output.register_hook(_store_grad)
3844

pytorch_grad_cam/activations_and_gradients_no_detach.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

pytorch_grad_cam/base_cam.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
compute_input_gradient: bool = False,
2020
uses_gradients: bool = True,
2121
tta_transforms: Optional[tta.Compose] = None,
22+
detach: bool = True,
2223
) -> None:
2324
self.model = model.eval()
2425
self.target_layers = target_layers
@@ -45,7 +46,8 @@ def __init__(
4546
else:
4647
self.tta_transforms = tta_transforms
4748

48-
self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform)
49+
self.detach = detach
50+
self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform, self.detach)
4951

5052
""" Get a vector of weights for every channel in the target layer.
5153
Methods that return weights channels,
@@ -71,6 +73,8 @@ def get_cam_image(
7173
eigen_smooth: bool = False,
7274
) -> np.ndarray:
7375
weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
76+
if isinstance(activations, torch.Tensor):
77+
activations = activations.cpu().detach().numpy()
7478
# 2D conv
7579
if len(activations.shape) == 4:
7680
weighted_activations = weights[:, :, None, None] * activations
@@ -132,8 +136,12 @@ def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int]
132136
def compute_cam_per_layer(
133137
self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
134138
) -> np.ndarray:
135-
activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
136-
grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
139+
if self.detach:
140+
activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
141+
grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
142+
else:
143+
activations_list = [a for a in self.activations_and_grads.activations]
144+
grads_list = [g for g in self.activations_and_grads.gradients]
137145
target_size = self.get_target_width_height(input_tensor)
138146

139147
cam_per_target_layer = []

pytorch_grad_cam/shapley_cam.py

Lines changed: 15 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,22 @@
11
from typing import Callable, List, Optional, Tuple
2-
3-
import numpy as np
4-
import torch
52
from pytorch_grad_cam.base_cam import BaseCAM
6-
from scipy.signal import convolve2d
7-
from scipy.ndimage import gaussian_filter
8-
import cv2
9-
10-
from pytorch_grad_cam.activations_and_gradients_no_detach import ActivationsAndGradients_no_detach
11-
from pytorch_grad_cam.utils.image import scale_cam_image
12-
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
13-
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
3+
import torch
4+
import numpy as np
145

156
"""
167
Weighting the activation maps using Gradient and Hessian-Vector Product.
17-
This method (https://arxiv.org/abs/2501.06261) reinterpret CAM methods from a Shapley value perspective.
8+
This method (https://arxiv.org/abs/2501.06261) reinterpret CAM methods (include GradCAM, HiResCAM and the original CAM) from a Shapley value perspective.
189
"""
1910
class ShapleyCAM(BaseCAM):
2011
def __init__(self, model, target_layers,
21-
reshape_transform=None):
12+
reshape_transform=None, detach=False):
2213
super(
2314
ShapleyCAM,
2415
self).__init__(
2516
model,
2617
target_layers,
27-
reshape_transform)
28-
29-
self.activations_and_grads = ActivationsAndGradients_no_detach(self.model, target_layers, reshape_transform)
18+
reshape_transform,
19+
detach = detach)
3020

3121
def forward(
3222
self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
@@ -44,6 +34,7 @@ def forward(
4434
if self.uses_gradients:
4535
self.model.zero_grad()
4636
loss = sum([target(output) for target, output in zip(targets, outputs)])
37+
# keep the graph
4738
torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
4839

4940
# In most of the saliency attribution papers, the saliency is
@@ -65,96 +56,36 @@ def get_cam_weights(self,
6556
target_category,
6657
activations,
6758
grads):
68-
activations: List[Tensor] # type: ignore[assignment]
69-
grads: List[Tensor] # type: ignore[assignment]
70-
59+
7160
hvp = torch.autograd.grad(
7261
outputs=grads,
7362
inputs=activations,
7463
grad_outputs=activations,
7564
retain_graph=False,
7665
allow_unused=True
7766
)[0]
78-
# print(torch.max(hvp[0]).item()) # verify that hvp is not all zeros
67+
# print(torch.max(hvp[0]).item()) # Use .item() to get the scalar value
7968
if hvp is None:
8069
hvp = torch.tensor(0).to(self.device)
81-
elif self.activations_and_grads.reshape_transform is not None:
82-
hvp = self.activations_and_grads.reshape_transform(hvp)
70+
else:
71+
if self.activations_and_grads.reshape_transform is not None:
72+
hvp = self.activations_and_grads.reshape_transform(hvp)
8373

8474
if self.activations_and_grads.reshape_transform is not None:
8575
activations = self.activations_and_grads.reshape_transform(activations)
8676
grads = self.activations_and_grads.reshape_transform(grads)
87-
weight = (grads - 0.5 * hvp).cpu().detach().numpy()
88-
activations = activations.cpu().detach().numpy()
89-
grads = grads.cpu().detach().numpy()
90-
9177

78+
weight = (grads - 0.5 * hvp).detach().cpu().numpy()
9279
# 2D image
9380
if len(activations.shape) == 4:
9481
weight = np.mean(weight, axis=(2, 3))
95-
return weight, activations
82+
return weight
9683

9784
# 3D image
9885
elif len(activations.shape) == 5:
9986
weight = np.mean(weight, axis=(2, 3, 4))
100-
return weight, activations
87+
return weight
10188

10289
else:
10390
raise ValueError("Invalid grads shape."
10491
"Shape of grads should be 4 (2D image) or 5 (3D image).")
105-
106-
107-
108-
def get_cam_image(
109-
self,
110-
input_tensor: torch.Tensor,
111-
target_layer: torch.nn.Module,
112-
targets: List[torch.nn.Module],
113-
activations: torch.Tensor,
114-
grads: torch.Tensor,
115-
eigen_smooth: bool = False,
116-
) -> np.ndarray:
117-
weights, activations = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
118-
119-
# 2D conv
120-
if len(activations.shape) == 4:
121-
weighted_activations = weights[:, :, None, None] * activations
122-
123-
# 3D conv
124-
elif len(activations.shape) == 5:
125-
weighted_activations = weights[:, :, None, None, None] * activations
126-
else:
127-
raise ValueError(f"Invalid activation shape. Get {len(activations.shape)}.")
128-
129-
# weighted_activations = np.maximum(weighted_activations, 0)
130-
# weighted_activations = np.abs(weighted_activations)
131-
if eigen_smooth:
132-
cam = get_2d_projection(weighted_activations)
133-
else:
134-
cam = weighted_activations.sum(axis=1)
135-
return cam
136-
137-
def compute_cam_per_layer(
138-
self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
139-
) -> np.ndarray:
140-
activations_list = [a for a in self.activations_and_grads.original_activations]
141-
grads_list = [g for g in self.activations_and_grads.original_gradients]
142-
target_size = self.get_target_width_height(input_tensor)
143-
144-
cam_per_target_layer = []
145-
# Loop over the saliency image from every layer
146-
for i in range(len(self.target_layers)):
147-
target_layer = self.target_layers[i]
148-
layer_activations = None
149-
layer_grads = None
150-
if i < len(activations_list):
151-
layer_activations = activations_list[i]
152-
if i < len(grads_list):
153-
layer_grads = grads_list[i]
154-
155-
cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)
156-
cam = np.maximum(cam, 0)
157-
scaled = scale_cam_image(cam, target_size)
158-
cam_per_target_layer.append(scaled[:, None, :])
159-
160-
return cam_per_target_layer

pytorch_grad_cam/utils/model_targets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __call__(self, model_output):
2525

2626
class ClassifierOutputReST:
2727
"""
28-
Using both pre-softmax and post-softmax, propoesed in https://arxiv.org/abs/2501.06261
28+
Using both pre-softmax and post-softmax, proposed in https://arxiv.org/abs/2501.06261
2929
"""
3030
def __init__(self, category):
3131
self.category = category
@@ -36,7 +36,7 @@ def __call__(self, model_output):
3636
return model_output[0][self.category] - torch.nn.functional.cross_entropy(model_output, target)
3737
else:
3838
target = torch.tensor([self.category] * model_output.shape[0], device=model_output.device)
39-
return model_output[:,self.category]- torch.nn.functional.cross_entropy(model_output, target)
39+
return model_output[:,self.category] - torch.nn.functional.cross_entropy(model_output, target)
4040

4141

4242
class BinaryClassifierOutputTarget:

0 commit comments

Comments
 (0)