Skip to content

Commit 25067d1

Browse files
authored
Update ablation_cam.py (#440)
Updated the comments for better understanding and suggested some syntax modification.
1 parent 09ac162 commit 25067d1

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

pytorch_grad_cam/ablation_cam.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
""" Implementation of AblationCAM
1111
https://openaccess.thecvf.com/content_WACV_2020/papers/Desai_Ablation-CAM_Visual_Explanations_for_Deep_Convolutional_Network_via_Gradient-free_Localization_WACV_2020_paper.pdf
1212
13-
Ablate individual activations, and then measure the drop in the target score.
13+
Ablate individual activations, and then measure the drop in the target scores.
1414
1515
In the current implementation, the target layer activations is cached, so it won't be re-computed.
1616
However layers before it, if any, will not be cached.
@@ -88,8 +88,8 @@ def get_cam_weights(self,
8888
[target(output).cpu().item() for target, output in zip(targets, outputs)])
8989

9090
# Replace the layer with the ablation layer.
91-
# When we finish, we will replace it back, so the original model is
92-
# unchanged.
91+
# When we finish, we will replace it back, so the
92+
# original model is unchanged.
9393
ablation_layer = self.ablation_layer
9494
replace_layer_recursive(self.model, target_layer, ablation_layer)
9595

@@ -122,9 +122,9 @@ def get_cam_weights(self,
122122
# Change the state of the ablation layer so it ablates the next channels.
123123
# TBD: Move this into the ablation layer forward pass.
124124
ablation_layer.set_next_batch(
125-
input_batch_index=batch_index,
126-
activations=self.activations,
127-
num_channels_to_ablate=batch_tensor.size(0))
125+
input_batch_index = batch_index,
126+
activations = self.activations,
127+
num_channels_to_ablate = batch_tensor.size(0))
128128
score = [target(o).cpu().item()
129129
for o in self.model(batch_tensor)]
130130
new_scores.extend(score)
@@ -145,4 +145,5 @@ def get_cam_weights(self,
145145

146146
# Replace the model back to the original state
147147
replace_layer_recursive(self.model, ablation_layer, target_layer)
148+
# Returning the weights from new_scores
148149
return weights

0 commit comments

Comments
 (0)