Skip to content

Commit 326300d

Browse files
forward function in shapely_cam.py still needed
This is because the calculation of the Hessian-vector product (HVP) requires the computation graph to be retained, see comments in line 37 or 38.
1 parent b6be00b commit 326300d

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pytorch_grad_cam/shapley_cam.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Callable, List, Optional, Tuple
22
from pytorch_grad_cam.base_cam import BaseCAM
3+
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
34
import torch
45
import numpy as np
56

@@ -34,7 +35,7 @@ def forward(
3435
if self.uses_gradients:
3536
self.model.zero_grad()
3637
loss = sum([target(output) for target, output in zip(targets, outputs)])
37-
# keep the graph
38+
# keep the graph, create_graph = True is needed for hvp
3839
torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
3940

4041
# In most of the saliency attribution papers, the saliency is

0 commit comments

Comments
 (0)