Skip to content

Commit 445d323

Browse files
comments
1 parent 0b85cba commit 445d323

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pytorch_grad_cam/shapley_cam.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def get_cam_weights(self,
7575
retain_graph=False,
7676
allow_unused=True
7777
)[0]
78+
# print(torch.max(hvp[0]).item()) # verify that hvp is not all zeros
7879
if hvp is None:
7980
hvp = torch.tensor(0).to(self.device)
8081
elif self.activations_and_grads.reshape_transform is not None:
@@ -83,7 +84,7 @@ def get_cam_weights(self,
8384
if self.activations_and_grads.reshape_transform is not None:
8485
activations = self.activations_and_grads.reshape_transform(activations)
8586
grads = self.activations_and_grads.reshape_transform(grads)
86-
weight = (grads - 0.5*hvp).cpu().detach().numpy()
87+
weight = (grads - 0.5 * hvp).cpu().detach().numpy()
8788
activations = activations.cpu().detach().numpy()
8889
grads = grads.cpu().detach().numpy()
8990

0 commit comments

Comments
 (0)