Skip to content

Commit 54fa32a

Browse files
authored
🐛 Add mypy Type Check tools/wsi_registration.py (#831)
- Fix type errors
1 parent 647d30b commit 54fa32a

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tiatoolbox/tools/registration/wsi_registration.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ class DFBRFeatureExtractor(torch.nn.Module):
332332
333333
"""
334334

335-
def __init__(self: torch.nn.Module) -> None:
335+
def __init__(self: DFBRFeatureExtractor) -> None:
336336
"""Initialize :class:`DFBRFeatureExtractor`."""
337337
super().__init__()
338338
output_layers_id: list[str] = ["16", "23", "30"]
@@ -434,8 +434,8 @@ class DFBRegister:
434434
def __init__(self: DFBRegister, patch_size: tuple[int, int] = (224, 224)) -> None:
435435
"""Initialize :class:`DFBRegister`."""
436436
self.patch_size = patch_size
437-
self.x_scale: list[float] = []
438-
self.y_scale: list[float] = []
437+
self.x_scale: np.ndarray
438+
self.y_scale: np.ndarray
439439
self.feature_extractor = DFBRFeatureExtractor()
440440

441441
# Make this function private when full pipeline is implemented.
@@ -796,7 +796,7 @@ def find_points_inside_boundary(mask: np.ndarray, points: np.ndarray) -> np.ndar
796796
return PatchExtractor.filter_coordinates(
797797
mask_reader,
798798
bbox_coord,
799-
mask.shape[::-1],
799+
(mask.shape[1], mask.shape[0]),
800800
)
801801

802802
def filtering_matching_points(
@@ -1521,21 +1521,21 @@ def get_patch_dimensions(
15211521
"""
15221522
width, height = size[0], size[1]
15231523

1524-
x = [
1524+
x_info = [
15251525
np.linspace(1, width, width, endpoint=True),
15261526
np.ones(height) * width,
15271527
np.linspace(1, width, width, endpoint=True),
15281528
np.ones(height),
15291529
]
1530-
x = np.array(list(itertools.chain.from_iterable(x)))
1530+
x = np.array(list(itertools.chain.from_iterable(x_info)))
15311531

1532-
y = [
1532+
y_info = [
15331533
np.ones(width),
15341534
np.linspace(1, height, height, endpoint=True),
15351535
np.ones(width) * height,
15361536
np.linspace(1, height, height, endpoint=True),
15371537
]
1538-
y = np.array(list(itertools.chain.from_iterable(y)))
1538+
y = np.array(list(itertools.chain.from_iterable(y_info)))
15391539

15401540
points = np.array([x, y]).transpose()
15411541
transform = transform * [[1, 1, 0], [1, 1, 0], [1, 1, 1]] # remove translation

0 commit comments

Comments
 (0)