Skip to content

Commit 9f2d539

Browse files
comments
1 parent c25ca78 commit 9f2d539

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

pytorch_grad_cam/base_cam.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def forward(
111111
loss.backward(retain_graph=True)
112112
else:
113113
# keep the computational graph, create_graph = True is needed for hvp
114-
torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
114+
torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
115115
# When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle"
116116
# loss.backward(retain_graph=True, create_graph=True)
117117
if 'hpu' in str(self.device):

pytorch_grad_cam/shapley_cam.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
"""
8-
Weighting the activation maps using Gradient and Hessian-Vector Product.
8+
Weights the activation maps using the gradient and Hessian-Vector product.
99
This method (https://arxiv.org/abs/2501.06261) reinterpret CAM methods (include GradCAM, HiResCAM and the original CAM) from a Shapley value perspective.
1010
"""
1111
class ShapleyCAM(BaseCAM):
@@ -51,12 +51,10 @@ def get_cam_weights(self,
5151
if len(activations.shape) == 4:
5252
weight = np.mean(weight, axis=(2, 3))
5353
return weight
54-
5554
# 3D image
5655
elif len(activations.shape) == 5:
5756
weight = np.mean(weight, axis=(2, 3, 4))
5857
return weight
59-
6058
else:
6159
raise ValueError("Invalid grads shape."
6260
"Shape of grads should be 4 (2D image) or 5 (3D image).")

0 commit comments

Comments
 (0)