From e42a7867c197202716a80c1c312a130f941d4269 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Wed, 14 Feb 2024 16:49:59 +0900 Subject: [PATCH] improve type hints --- pytorch_grad_cam/base_cam.py | 13 +++++++------ pytorch_grad_cam/utils/image.py | 7 ++++--- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/pytorch_grad_cam/base_cam.py b/pytorch_grad_cam/base_cam.py index f08cd12bb..a04ec01e7 100644 --- a/pytorch_grad_cam/base_cam.py +++ b/pytorch_grad_cam/base_cam.py @@ -1,4 +1,5 @@ import numpy as np +import numpy.typing as npt import torch import ttach as tta from typing import Callable, List, Tuple, Optional @@ -12,7 +13,7 @@ class BaseCAM: def __init__(self, model: torch.nn.Module, target_layers: List[torch.nn.Module], - reshape_transform: Callable = None, + reshape_transform: Optional[Callable] = None, compute_input_gradient: bool = False, uses_gradients: bool = True, tta_transforms: Optional[tta.Compose] = None) -> None: @@ -71,8 +72,8 @@ def get_cam_image(self, def forward(self, input_tensor: torch.Tensor, - targets: List[torch.nn.Module], - eigen_smooth: bool = False) -> np.ndarray: + targets: Optional[List[torch.nn.Module]], + eigen_smooth: bool = False) -> npt.NDArray[np.float32]: input_tensor = input_tensor.to(self.device) @@ -156,7 +157,7 @@ def aggregate_multi_layers( def forward_augmentation_smoothing(self, input_tensor: torch.Tensor, - targets: List[torch.nn.Module], + targets: Optional[List[torch.nn.Module]], eigen_smooth: bool = False) -> np.ndarray: cams = [] for transform in self.tta_transforms: @@ -180,9 +181,9 @@ def forward_augmentation_smoothing(self, def __call__(self, input_tensor: torch.Tensor, - targets: List[torch.nn.Module] = None, + targets: Optional[List[torch.nn.Module]]= None, aug_smooth: bool = False, - eigen_smooth: bool = False) -> np.ndarray: + eigen_smooth: bool = False) -> npt.NDArray[np.float32]: # Smooth the CAM result with test time augmentation if aug_smooth is True: diff --git a/pytorch_grad_cam/utils/image.py b/pytorch_grad_cam/utils/image.py index 34d92ba6f..3dfb38322 100644 --- a/pytorch_grad_cam/utils/image.py +++ b/pytorch_grad_cam/utils/image.py @@ -3,6 +3,7 @@ from matplotlib.lines import Line2D import cv2 import numpy as np +import numpy.typing as npt import torch from torchvision.transforms import Compose, Normalize, ToTensor from typing import List, Dict @@ -30,11 +31,11 @@ def deprocess_image(img): return np.uint8(img * 255) -def show_cam_on_image(img: np.ndarray, - mask: np.ndarray, +def show_cam_on_image(img: npt.NDArray[np.float16] | npt.NDArray[np.float32] | npt.NDArray[np.float64], + mask: npt.NDArray[np.float16] | npt.NDArray[np.float32] | npt.NDArray[np.float64], use_rgb: bool = False, colormap: int = cv2.COLORMAP_JET, - image_weight: float = 0.5) -> np.ndarray: + image_weight: float = 0.5) -> npt.NDArray[np.uint8]: """ This function overlays the cam mask on the image as an heatmap. By default the heatmap is in BGR format.