diff --git a/pytorch_grad_cam/ablation_cam.py b/pytorch_grad_cam/ablation_cam.py index 252b5b078..7b30eea11 100644 --- a/pytorch_grad_cam/ablation_cam.py +++ b/pytorch_grad_cam/ablation_cam.py @@ -1,11 +1,12 @@ +from typing import Callable, List, Optional + import numpy as np import torch import tqdm -from typing import Callable, List + +from pytorch_grad_cam.ablation_layer import AblationLayer from pytorch_grad_cam.base_cam import BaseCAM from pytorch_grad_cam.utils.find_layers import replace_layer_recursive -from pytorch_grad_cam.ablation_layer import AblationLayer - """ Implementation of AblationCAM https://openaccess.thecvf.com/content_WACV_2020/papers/Desai_Ablation-CAM_Visual_Explanations_for_Deep_Convolutional_Network_via_Gradient-free_Localization_WACV_2020_paper.pdf @@ -25,13 +26,15 @@ class AblationCAM(BaseCAM): - def __init__(self, - model: torch.nn.Module, - target_layers: List[torch.nn.Module], - reshape_transform: Callable = None, - ablation_layer: torch.nn.Module = AblationLayer(), - batch_size: int = 32, - ratio_channels_to_ablate: float = 1.0) -> None: + def __init__( + self, + model: torch.nn.Module, + target_layers: List[torch.nn.Module], + reshape_transform: Optional[Callable] = None, + ablation_layer: torch.nn.Module = AblationLayer(), + batch_size: int = 32, + ratio_channels_to_ablate: float = 1.0, + ) -> None: super(AblationCAM, self).__init__(model, target_layers, diff --git a/pytorch_grad_cam/base_cam.py b/pytorch_grad_cam/base_cam.py index 484e8865f..147f2752e 100644 --- a/pytorch_grad_cam/base_cam.py +++ b/pytorch_grad_cam/base_cam.py @@ -15,7 +15,7 @@ def __init__( self, model: torch.nn.Module, target_layers: List[torch.nn.Module], - reshape_transform: Callable = None, + reshape_transform: Optional[Callable] = None, compute_input_gradient: bool = False, uses_gradients: bool = True, tta_transforms: Optional[tta.Compose] = None, @@ -91,7 +91,10 @@ def get_cam_image( return cam def forward( - self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False + self, + input_tensor: torch.Tensor, + targets: Optional[List[torch.nn.Module]], + eigen_smooth: bool = False, ) -> np.ndarray: input_tensor = input_tensor.to(self.device) @@ -129,7 +132,9 @@ def forward( cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth) return self.aggregate_multi_layers(cam_per_layer) - def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int]: + def get_target_width_height( + self, input_tensor: torch.Tensor + ) -> Tuple[int, int] | Tuple[int, int, int]: if len(input_tensor.shape) == 4: width, height = input_tensor.size(-1), input_tensor.size(-2) return width, height @@ -175,7 +180,10 @@ def aggregate_multi_layers(self, cam_per_target_layer: np.ndarray) -> np.ndarray return scale_cam_image(result) def forward_augmentation_smoothing( - self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False + self, + input_tensor: torch.Tensor, + targets: Optional[List[torch.nn.Module]], + eigen_smooth: bool = False, ) -> np.ndarray: cams = [] for transform in self.tta_transforms: @@ -198,7 +206,7 @@ def forward_augmentation_smoothing( def __call__( self, input_tensor: torch.Tensor, - targets: List[torch.nn.Module] = None, + targets: Optional[List[torch.nn.Module]], aug_smooth: bool = False, eigen_smooth: bool = False, ) -> np.ndarray: diff --git a/pytorch_grad_cam/feature_factorization/deep_feature_factorization.py b/pytorch_grad_cam/feature_factorization/deep_feature_factorization.py index b9db2c3e3..60a0343b0 100644 --- a/pytorch_grad_cam/feature_factorization/deep_feature_factorization.py +++ b/pytorch_grad_cam/feature_factorization/deep_feature_factorization.py @@ -45,12 +45,13 @@ class DeepFeatureFactorization: and to the input tensor width and height. """ - def __init__(self, - model: torch.nn.Module, - target_layer: torch.nn.Module, - reshape_transform: Callable = None, - computation_on_concepts=None - ): + def __init__( + self, + model: torch.nn.Module, + target_layer: torch.nn.Module, + reshape_transform: Optional[Callable] = None, + computation_on_concepts=None, + ): self.model = model self.computation_on_concepts = computation_on_concepts self.activations_and_grads = ActivationsAndGradients( @@ -95,14 +96,16 @@ def __exit__(self, exc_type, exc_value, exc_tb): return True -def run_dff_on_image(model: torch.nn.Module, - target_layer: torch.nn.Module, - classifier: torch.nn.Module, - img_pil: Image, - img_tensor: torch.Tensor, - reshape_transform=Optional[Callable], - n_components: int = 5, - top_k: int = 2) -> np.ndarray: +def run_dff_on_image( + model: torch.nn.Module, + target_layer: torch.nn.Module, + classifier: torch.nn.Module, + img_pil: Image.Image, + img_tensor: torch.Tensor, + reshape_transform=Optional[Callable], + n_components: int = 5, + top_k: int = 2, +) -> np.ndarray: """ Helper function to create a Deep Feature Factorization visualization for a single image. TBD: Run this on a batch with several images. """ diff --git a/pytorch_grad_cam/finer_cam.py b/pytorch_grad_cam/finer_cam.py index f5232f4a6..a351c3a10 100644 --- a/pytorch_grad_cam/finer_cam.py +++ b/pytorch_grad_cam/finer_cam.py @@ -1,10 +1,12 @@ +from typing import Callable, List, Optional + import numpy as np import torch -from typing import List, Callable -from pytorch_grad_cam.base_cam import BaseCAM + from pytorch_grad_cam import GradCAM from pytorch_grad_cam.utils.model_targets import FinerWeightedTarget + class FinerCAM: def __init__(self, model: torch.nn.Module, target_layers: List[torch.nn.Module], reshape_transform: Callable = None, base_method=GradCAM): self.base_cam = base_method(model, target_layers, reshape_transform) @@ -14,9 +16,15 @@ def __init__(self, model: torch.nn.Module, target_layers: List[torch.nn.Module], def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) - def forward(self, input_tensor: torch.Tensor, targets: List[torch.nn.Module] = None, eigen_smooth: bool = False, - alpha: float = 1, comparison_categories: List[int] = [1, 2, 3], target_idx: int = None - ) -> np.ndarray: + def forward( + self, + input_tensor: torch.Tensor, + targets: Optional[List[torch.nn.Module]] = None, + eigen_smooth: bool = False, + alpha: float = 1, + comparison_categories: List[int] = [1, 2, 3], + target_idx: Optional[int] = None, + ) -> np.ndarray: input_tensor = input_tensor.to(self.base_cam.device) if self.compute_input_gradient: diff --git a/pytorch_grad_cam/utils/image.py b/pytorch_grad_cam/utils/image.py index 44f164803..83548b367 100644 --- a/pytorch_grad_cam/utils/image.py +++ b/pytorch_grad_cam/utils/image.py @@ -1,8 +1,6 @@ -import math -from typing import Dict, List +from typing import Dict, List, Optional import cv2 -import matplotlib import numpy as np import torch from matplotlib import pyplot as plt @@ -82,11 +80,13 @@ def create_labels_legend(concept_scores: np.ndarray, return concept_labels_topk -def show_factorization_on_image(img: np.ndarray, - explanations: np.ndarray, - colors: List[np.ndarray] = None, - image_weight: float = 0.5, - concept_labels: List = None) -> np.ndarray: +def show_factorization_on_image( + img: np.ndarray, + explanations: np.ndarray, + colors: Optional[List[np.ndarray]] = None, + image_weight: float = 0.5, + concept_labels: Optional[list] = None, +) -> np.ndarray: """ Color code the different component heatmaps on top of the image. Every component color code will be magnified according to the heatmap itensity (by modifying the V channel in the HSV color space),