diff --git a/tiatoolbox/tools/patchextraction.py b/tiatoolbox/tools/patchextraction.py index 3de867a19..a6caf082c 100644 --- a/tiatoolbox/tools/patchextraction.py +++ b/tiatoolbox/tools/patchextraction.py @@ -9,6 +9,7 @@ from typing_extensions import Unpack from tiatoolbox import logger +from tiatoolbox.annotation.storage import AnnotationStore from tiatoolbox.utils import misc from tiatoolbox.utils.exceptions import FileNotSupportedError, MethodNotSupportedError from tiatoolbox.utils.visualization import AnnotationRenderer @@ -19,7 +20,6 @@ from pandas import DataFrame - from tiatoolbox.annotation.storage import AnnotationStore from tiatoolbox.type_hints import Resolution, Units @@ -237,7 +237,9 @@ def __init__( if input_mask is None: self.mask = None - elif isinstance(input_mask, str) and input_mask.endswith(".db"): + elif (isinstance(input_mask, str) and input_mask.endswith(".db")) or isinstance( + input_mask, AnnotationStore + ): # input_mask is an annotation store renderer = AnnotationRenderer( max_scale=10000, edge_thickness=0, where=store_filter @@ -670,7 +672,12 @@ def __init__( # noqa: PLR0913 self: SlidingWindowPatchExtractor, input_img: str | Path | np.ndarray | wsireader.WSIReader, patch_size: int | tuple[int, int], - input_mask: str | Path | np.ndarray | wsireader.VirtualWSIReader | None = None, + input_mask: str + | Path + | np.ndarray + | wsireader.VirtualWSIReader + | AnnotationStore + | None = None, resolution: Resolution = 0, units: Units = "level", stride: int | tuple[int, int] | None = None, diff --git a/tiatoolbox/tools/registration/wsi_registration.py b/tiatoolbox/tools/registration/wsi_registration.py index d0a6fc38a..fdca5200a 100644 --- a/tiatoolbox/tools/registration/wsi_registration.py +++ b/tiatoolbox/tools/registration/wsi_registration.py @@ -3,7 +3,7 @@ from __future__ import annotations import itertools -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, cast import cv2 import numpy as np @@ -338,11 +338,16 @@ def __init__(self: DFBRFeatureExtractor) -> None: super().__init__() output_layers_id: list[str] = ["16", "23", "30"] output_layers_key: list[str] = ["block3_pool", "block4_pool", "block5_pool"] - self.features: dict = dict.fromkeys(output_layers_key, None) - self.pretrained: torch.nn.Sequential = compile_model( + self.features: dict[str, torch.Tensor] = dict.fromkeys( + output_layers_key, torch.Tensor() + ) + + compiled_model = compile_model( torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1), mode=rcParam["torch_compile_mode"], - ).features + ) + self.pretrained = cast("torch.nn.Module", compiled_model.features) + self.f_hooks = [ getattr(self.pretrained, layer).register_forward_hook( self.forward_hook(output_layers_key[i]), @@ -350,7 +355,7 @@ def __init__(self: DFBRFeatureExtractor) -> None: for i, layer in enumerate(output_layers_id) ] - def forward_hook(self: torch.nn.Module, layer_name: str) -> Callable: + def forward_hook(self: DFBRFeatureExtractor, layer_name: str) -> Callable: """Register a hook. Args: @@ -386,7 +391,7 @@ def hook( return hook - def forward(self: torch.nn.Module, x: torch.Tensor) -> dict[str, torch.Tensor]: + def forward(self: DFBRFeatureExtractor, x: torch.Tensor) -> dict[str, torch.Tensor]: """Forward pass for feature extraction. Args: diff --git a/tiatoolbox/utils/transforms.py b/tiatoolbox/utils/transforms.py index 8fd096b32..8c2817b75 100644 --- a/tiatoolbox/utils/transforms.py +++ b/tiatoolbox/utils/transforms.py @@ -189,7 +189,7 @@ def imresize( img_channels = [ cv2.resize( src=img[..., ch], - dsize=output_size_array, + dsize=(output_size_array[0], output_size_array[1]), interpolation=cv2_interpolation, )[ ..., @@ -199,7 +199,11 @@ def imresize( ] return np.concatenate(img_channels, axis=-1) - return cv2.resize(src=img, dsize=output_size_array, interpolation=cv2_interpolation) + return cv2.resize( + src=img, + dsize=(output_size_array[0], output_size_array[1]), + interpolation=cv2_interpolation, + ) def rgb2od(img: np.ndarray) -> np.ndarray: diff --git a/tiatoolbox/utils/visualization.py b/tiatoolbox/utils/visualization.py index 74a89242c..8d7752d31 100644 --- a/tiatoolbox/utils/visualization.py +++ b/tiatoolbox/utils/visualization.py @@ -475,14 +475,15 @@ def overlay_prediction_contours( inst_colours_array = inst_colours_array.astype(np.uint8) for idx, [_, inst_info] in enumerate(inst_dict.items()): - inst_contour = inst_info["contour"] + inst_contour: np.ndarray = inst_info["contour"] if "type" in inst_info and type_colours is not None: inst_colour = type_colours[inst_info["type"]][1] else: inst_colour = (inst_colours_array[idx]).tolist() + contours: list[np.ndarray] = [np.array(inst_contour)] cv2.drawContours( overlay, - [np.array(inst_contour)], + contours, -1, inst_colour, line_thickness, @@ -881,9 +882,10 @@ def render_line( top_left, scale, ) + pts: list[np.ndarray] = [np.array(cnt)] cv2.polylines( tile, - [np.array(cnt)], + pts, isClosed=False, color=col, thickness=3,