Skip to content

Commit c25ca78

Browse files
delete forward function in shapley_cam.py
1 parent 326300d commit c25ca78

File tree

2 files changed

+14
-38
lines changed

2 files changed

+14
-38
lines changed

pytorch_grad_cam/base_cam.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,13 @@ def forward(
107107
if self.uses_gradients:
108108
self.model.zero_grad()
109109
loss = sum([target(output) for target, output in zip(targets, outputs)])
110-
loss.backward(retain_graph=True)
110+
if self.detach:
111+
loss.backward(retain_graph=True)
112+
else:
113+
# keep the computational graph, create_graph = True is needed for hvp
114+
torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
115+
# When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle"
116+
# loss.backward(retain_graph=True, create_graph=True)
111117
if 'hpu' in str(self.device):
112118
self.__htcore.mark_step()
113119

pytorch_grad_cam/shapley_cam.py

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,46 +10,16 @@
1010
"""
1111
class ShapleyCAM(BaseCAM):
1212
def __init__(self, model, target_layers,
13-
reshape_transform=None, detach=False):
13+
reshape_transform=None):
1414
super(
1515
ShapleyCAM,
1616
self).__init__(
17-
model,
18-
target_layers,
19-
reshape_transform,
20-
detach = detach)
21-
22-
def forward(
23-
self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
24-
) -> np.ndarray:
25-
input_tensor = input_tensor.to(self.device)
26-
27-
input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)
28-
29-
self.outputs = outputs = self.activations_and_grads(input_tensor)
30-
31-
if targets is None:
32-
target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
33-
targets = [ClassifierOutputTarget(category) for category in target_categories]
34-
35-
if self.uses_gradients:
36-
self.model.zero_grad()
37-
loss = sum([target(output) for target, output in zip(targets, outputs)])
38-
# keep the graph, create_graph = True is needed for hvp
39-
torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
40-
41-
# In most of the saliency attribution papers, the saliency is
42-
# computed with a single target layer.
43-
# Commonly it is the last convolutional layer.
44-
# Here we support passing a list with multiple target layers.
45-
# It will compute the saliency image for every image,
46-
# and then aggregate them (with a default mean aggregation).
47-
# This gives you more flexibility in case you just want to
48-
# use all conv layers for example, all Batchnorm layers,
49-
# or something else.
50-
cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
51-
return self.aggregate_multi_layers(cam_per_layer)
52-
17+
model = model,
18+
target_layers = target_layers,
19+
reshape_transform = reshape_transform,
20+
compute_input_gradient = True,
21+
uses_gradients = True,
22+
detach = False)
5323

5424
def get_cam_weights(self,
5525
input_tensor,

0 commit comments

Comments
 (0)