Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions tiatoolbox/tools/patchextraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing_extensions import Unpack

from tiatoolbox import logger
from tiatoolbox.annotation.storage import AnnotationStore
from tiatoolbox.utils import misc
from tiatoolbox.utils.exceptions import FileNotSupportedError, MethodNotSupportedError
from tiatoolbox.utils.visualization import AnnotationRenderer
Expand All @@ -19,7 +20,6 @@

from pandas import DataFrame

from tiatoolbox.annotation.storage import AnnotationStore
from tiatoolbox.type_hints import Resolution, Units


Expand Down Expand Up @@ -237,7 +237,9 @@ def __init__(

if input_mask is None:
self.mask = None
elif isinstance(input_mask, str) and input_mask.endswith(".db"):
elif (isinstance(input_mask, str) and input_mask.endswith(".db")) or isinstance(
input_mask, AnnotationStore
):
# input_mask is an annotation store
renderer = AnnotationRenderer(
max_scale=10000, edge_thickness=0, where=store_filter
Expand Down Expand Up @@ -670,7 +672,12 @@ def __init__( # noqa: PLR0913
self: SlidingWindowPatchExtractor,
input_img: str | Path | np.ndarray | wsireader.WSIReader,
patch_size: int | tuple[int, int],
input_mask: str | Path | np.ndarray | wsireader.VirtualWSIReader | None = None,
input_mask: str
| Path
| np.ndarray
| wsireader.VirtualWSIReader
| AnnotationStore
| None = None,
resolution: Resolution = 0,
units: Units = "level",
stride: int | tuple[int, int] | None = None,
Expand Down
17 changes: 11 additions & 6 deletions tiatoolbox/tools/registration/wsi_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import itertools
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, cast

import cv2
import numpy as np
Expand Down Expand Up @@ -338,19 +338,24 @@ def __init__(self: DFBRFeatureExtractor) -> None:
super().__init__()
output_layers_id: list[str] = ["16", "23", "30"]
output_layers_key: list[str] = ["block3_pool", "block4_pool", "block5_pool"]
self.features: dict = dict.fromkeys(output_layers_key, None)
self.pretrained: torch.nn.Sequential = compile_model(
self.features: dict[str, torch.Tensor] = dict.fromkeys(
output_layers_key, torch.Tensor()
)

compiled_model = compile_model(
torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1),
mode=rcParam["torch_compile_mode"],
).features
)
self.pretrained = cast("torch.nn.Module", compiled_model.features)

self.f_hooks = [
getattr(self.pretrained, layer).register_forward_hook(
self.forward_hook(output_layers_key[i]),
)
for i, layer in enumerate(output_layers_id)
]

def forward_hook(self: torch.nn.Module, layer_name: str) -> Callable:
def forward_hook(self: DFBRFeatureExtractor, layer_name: str) -> Callable:
"""Register a hook.

Args:
Expand Down Expand Up @@ -386,7 +391,7 @@ def hook(

return hook

def forward(self: torch.nn.Module, x: torch.Tensor) -> dict[str, torch.Tensor]:
def forward(self: DFBRFeatureExtractor, x: torch.Tensor) -> dict[str, torch.Tensor]:
"""Forward pass for feature extraction.

Args:
Expand Down
8 changes: 6 additions & 2 deletions tiatoolbox/utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def imresize(
img_channels = [
cv2.resize(
src=img[..., ch],
dsize=output_size_array,
dsize=(output_size_array[0], output_size_array[1]),
interpolation=cv2_interpolation,
)[
...,
Expand All @@ -199,7 +199,11 @@ def imresize(
]
return np.concatenate(img_channels, axis=-1)

return cv2.resize(src=img, dsize=output_size_array, interpolation=cv2_interpolation)
return cv2.resize(
src=img,
dsize=(output_size_array[0], output_size_array[1]),
interpolation=cv2_interpolation,
)


def rgb2od(img: np.ndarray) -> np.ndarray:
Expand Down
8 changes: 5 additions & 3 deletions tiatoolbox/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,14 +475,15 @@ def overlay_prediction_contours(
inst_colours_array = inst_colours_array.astype(np.uint8)

for idx, [_, inst_info] in enumerate(inst_dict.items()):
inst_contour = inst_info["contour"]
inst_contour: np.ndarray = inst_info["contour"]
if "type" in inst_info and type_colours is not None:
inst_colour = type_colours[inst_info["type"]][1]
else:
inst_colour = (inst_colours_array[idx]).tolist()
contours: list[np.ndarray] = [np.array(inst_contour)]
cv2.drawContours(
overlay,
[np.array(inst_contour)],
contours,
-1,
inst_colour,
line_thickness,
Expand Down Expand Up @@ -881,9 +882,10 @@ def render_line(
top_left,
scale,
)
pts: list[np.ndarray] = [np.array(cnt)]
cv2.polylines(
tile,
[np.array(cnt)],
pts,
isClosed=False,
color=col,
thickness=3,
Expand Down