 """
 class ShapleyCAM(BaseCAM):
     def __init__(self, model, target_layers,
-                 reshape_transform=None, detach=False):
+                 reshape_transform=None):
         super(
             ShapleyCAM,
             self).__init__(
-            model,
-            target_layers,
-            reshape_transform,
-            detach = detach)
-
-    def forward(
-        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
-    ) -> np.ndarray:
-        input_tensor = input_tensor.to(self.device)
-
-        input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)
-
-        self.outputs = outputs = self.activations_and_grads(input_tensor)
-
-        if targets is None:
-            target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
-            targets = [ClassifierOutputTarget(category) for category in target_categories]
-
-        if self.uses_gradients:
-            self.model.zero_grad()
-            loss = sum([target(output) for target, output in zip(targets, outputs)])
-            # keep the graph, create_graph = True is needed for hvp
-            torch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)
-
-        # In most of the saliency attribution papers, the saliency is
-        # computed with a single target layer.
-        # Commonly it is the last convolutional layer.
-        # Here we support passing a list with multiple target layers.
-        # It will compute the saliency image for every image,
-        # and then aggregate them (with a default mean aggregation).
-        # This gives you more flexibility in case you just want to
-        # use all conv layers for example, all Batchnorm layers,
-        # or something else.
-        cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
-        return self.aggregate_multi_layers(cam_per_layer)
-
+            model = model,
+            target_layers = target_layers,
+            reshape_transform = reshape_transform,
+            compute_input_gradient = True,
+            uses_gradients = True,
+            detach = False)

     def get_cam_weights(self,
                         input_tensor,
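The removed forward() duplicated the generic CAM pipeline; after this change ShapleyCAM relies on BaseCAM and only differs in the flags it passes to the base constructor. A minimal usage sketch, assuming ShapleyCAM is exported from pytorch_grad_cam like the other CAM classes and follows the usual cam(input_tensor=..., targets=...) call convention; the backbone, target layer, and class index below are purely illustrative:

import torch
from torchvision.models import resnet50
from pytorch_grad_cam import ShapleyCAM  # assumption: exported like GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

model = resnet50(weights=None).eval()       # illustrative backbone
target_layers = [model.layer4[-1]]          # last conv block, the usual choice
input_tensor = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image batch

# After this commit the constructor configures BaseCAM with
# compute_input_gradient=True, uses_gradients=True, detach=False,
# so no custom forward() is needed on the subclass.
cam = ShapleyCAM(model=model, target_layers=target_layers)
grayscale_cam = cam(input_tensor=input_tensor,
                    targets=[ClassifierOutputTarget(281)])  # 281: an ImageNet class index
print(grayscale_cam.shape)  # expected: (1, 224, 224), one saliency map per input image

As the removed comment block explains, target_layers may also contain several layers; a CAM is computed for each one and the per-layer maps are mean-aggregated.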
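The one non-obvious line in the removed code is the gradient call annotated with "create_graph = True is needed for hvp": the input gradient is kept differentiable so a Hessian-vector product can be taken from it afterwards, and with detach=False that responsibility now sits in BaseCAM. A self-contained sketch of that autograd pattern on a toy function (not the library's code):

import torch

x = torch.randn(4, requires_grad=True)
loss = (x ** 3).sum()            # toy objective, f(x) = sum(x_i^3)

# First-order gradient, built with create_graph=True so it stays differentiable.
(grad,) = torch.autograd.grad(loss, x, retain_graph=True, create_graph=True)

# Hessian-vector product H @ v, obtained by differentiating <grad, v> w.r.t. x.
v = torch.ones_like(x)
(hvp,) = torch.autograd.grad(grad, x, grad_outputs=v)

# For f(x) = sum(x_i^3) the Hessian is diag(6 * x), so H @ 1 equals 6 * x.
print(torch.allclose(hvp, 6 * x))  # True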