Skip to content

Commit a9e34c3

Browse files
Jiaqi-Lvshaneahmedpre-commit-ci[bot]
authored
🔧 mypy Type Check tiatoolbox/models (#912)
Add type checks to: - `tiatoolbox/models/__init__.py` - `tiatoolbox/models/models_abc.py` - `tiatoolbox/models/architecture/__init__.py` - `tiatoolbox/models/architecture/utils.py` * fix bug * add architecture/utils.py * fix model_abc.py * fix utils.py * try to fix pytest * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * pin glymur version < 0.14 * pin glymur version < 0.14 * improve test coverage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ruff --------- Co-authored-by: Shan E Ahmed Raza <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1f2282b commit a9e34c3

File tree

5 files changed

+67
-30
lines changed

5 files changed

+67
-30
lines changed

.github/workflows/mypy-type-check.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,8 @@ jobs:
4646
tiatoolbox/tools \
4747
tiatoolbox/data \
4848
tiatoolbox/annotation \
49-
tiatoolbox/cli/common.py
49+
tiatoolbox/cli/common.py \
50+
tiatoolbox/models/__init__.py \
51+
tiatoolbox/models/models_abc.py \
52+
tiatoolbox/models/architecture/__init__.py \
53+
tiatoolbox/models/architecture/utils.py \

tests/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,10 +1614,17 @@ def test_fetch_pretrained_weights(tmp_path: Path) -> None:
16141614
fetch_pretrained_weights(model_name="mobilenet_v3_small-pcam", save_path=file_path)
16151615
assert file_path.exists()
16161616
assert file_path.stat().st_size > 0
1617+
file_path.unlink()
16171618

16181619
with pytest.raises(ValueError, match="does not exist"):
16191620
fetch_pretrained_weights("abc", file_path)
16201621

1622+
# Test save_path is str
1623+
file_path_str = str(file_path)
1624+
file_path = fetch_pretrained_weights("mobilenet_v3_small-pcam", file_path_str)
1625+
assert Path(file_path).exists()
1626+
assert Path(file_path).stat().st_size > 0
1627+
16211628

16221629
def test_imwrite(tmp_path: Path) -> NoReturn:
16231630
"""Create a temporary file path."""

tiatoolbox/models/architecture/__init__.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from __future__ import annotations
44

5-
import os
5+
from pathlib import Path
66
from pydoc import locate
7-
from typing import TYPE_CHECKING, Optional, Union
7+
from typing import TYPE_CHECKING
88

99
import torch
1010

@@ -13,8 +13,6 @@
1313
from tiatoolbox.utils import download_data
1414

1515
if TYPE_CHECKING: # pragma: no cover
16-
from pathlib import Path
17-
1816
from tiatoolbox.models.models_abc import IOConfigABC
1917

2018

@@ -53,10 +51,14 @@ def fetch_pretrained_weights(
5351

5452
if save_path is None:
5553
file_name = info["url"].split("/")[-1]
56-
save_path = rcParam["TIATOOLBOX_HOME"] / "models" / file_name
54+
processed_save_path = rcParam["TIATOOLBOX_HOME"] / "models" / file_name
55+
elif type(save_path) is str:
56+
processed_save_path = Path(save_path)
57+
else:
58+
processed_save_path = save_path
5759

58-
download_data(info["url"], save_path=save_path, overwrite=overwrite)
59-
return save_path
60+
download_data(info["url"], save_path=processed_save_path, overwrite=overwrite)
61+
return processed_save_path
6062

6163

6264
def get_pretrained_model(
@@ -129,9 +131,15 @@ def get_pretrained_model(
129131
info = PRETRAINED_INFO[pretrained_model]
130132

131133
arch_info = info["architecture"]
132-
creator = locate(f"tiatoolbox.models.architecture.{arch_info['class']}")
133-
134-
model = creator(**arch_info["kwargs"])
134+
model_class_info = arch_info["class"]
135+
model_module_name = str(".".join(model_class_info.split(".")[:-1]))
136+
model_name = str(model_class_info.split(".")[-1])
137+
138+
# Import module containing required model class
139+
arch_module = locate(f"tiatoolbox.models.architecture.{model_module_name}")
140+
# Get model class form module
141+
model_class = getattr(arch_module, model_name)
142+
model = model_class(**arch_info["kwargs"])
135143
# TODO(TBC): Dictionary of dataset specific or transformation? # noqa: FIX002,TD003
136144
if "dataset" in info:
137145
# ! this is a hack currently, need another PR to clean up
@@ -152,7 +160,12 @@ def get_pretrained_model(
152160
# !
153161

154162
io_info = info["ioconfig"]
155-
creator = locate(f"tiatoolbox.models.engine.{io_info['class']}")
163+
io_class_info = io_info["class"]
164+
io_module_name = str(".".join(io_class_info.split(".")[:-1]))
165+
io_class_name = str(io_class_info.split(".")[-1])
166+
167+
engine_module = locate(f"tiatoolbox.models.engine.{io_module_name}")
168+
engine_class = getattr(engine_module, io_class_name)
156169

157-
iostate = creator(**io_info["kwargs"])
170+
iostate = engine_class(**io_info["kwargs"])
158171
return model, iostate

tiatoolbox/models/architecture/utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import sys
6+
from typing import cast
67

78
import numpy as np
89
import torch
@@ -45,7 +46,7 @@ def is_torch_compile_compatible() -> bool:
4546

4647

4748
def compile_model(
48-
model: nn.Module | None = None,
49+
model: nn.Module,
4950
*,
5051
mode: str = "default",
5152
) -> nn.Module:
@@ -97,12 +98,12 @@ def compile_model(
9798
)
9899
return model
99100

100-
return torch.compile(model, mode=mode) # pragma: no cover
101+
return cast(nn.Module, torch.compile(model, mode=mode)) # pragma: no cover
101102

102103

103104
def centre_crop(
104-
img: np.ndarray | torch.tensor,
105-
crop_shape: np.ndarray | torch.tensor,
105+
img: np.ndarray | torch.Tensor,
106+
crop_shape: np.ndarray | torch.Tensor | tuple,
106107
data_format: str = "NCHW",
107108
) -> np.ndarray | torch.Tensor:
108109
"""A function to center crop image with given crop shape.
@@ -136,8 +137,8 @@ def centre_crop(
136137

137138

138139
def centre_crop_to_shape(
139-
x: np.ndarray | torch.tensor,
140-
y: np.ndarray | torch.tensor,
140+
x: np.ndarray | torch.Tensor,
141+
y: np.ndarray | torch.Tensor,
141142
data_format: str = "NCHW",
142143
) -> np.ndarray | torch.Tensor:
143144
"""A function to center crop image to shape.
@@ -200,6 +201,7 @@ def __init__(self: UpSample2x) -> None:
200201
"""Initialize :class:`UpSample2x`."""
201202
super().__init__()
202203
# correct way to create constant within module
204+
self.unpool_mat: torch.Tensor
203205
self.register_buffer(
204206
"unpool_mat",
205207
torch.from_numpy(np.ones((2, 2), dtype="float32")),

tiatoolbox/models/models_abc.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77

88
import torch
99
import torch._dynamo
10-
from torch import device as torch_device
1110

1211
torch._dynamo.config.suppress_errors = True # skipcq: PYL-W0212 # noqa: SLF001
1312

14-
1513
if TYPE_CHECKING: # pragma: no cover
1614
from pathlib import Path
1715

@@ -57,8 +55,8 @@ def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module:
5755
# DataParallel work only for cuda
5856
model = torch.nn.DataParallel(model)
5957

60-
device = torch.device(device)
61-
return model.to(device)
58+
torch_device = torch.device(device)
59+
return model.to(torch_device)
6260

6361

6462
class ModelABC(ABC, torch.nn.Module):
@@ -72,7 +70,9 @@ def __init__(self: ModelABC) -> None:
7270

7371
@abstractmethod
7472
# This is generic abc, else pylint will complain
75-
def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None:
73+
def forward(
74+
self: ModelABC, *args: tuple[Any, ...], **kwargs: dict
75+
) -> None | torch.Tensor:
7676
"""Torch method, this contains logic for using layers defined in init."""
7777
... # pragma: no cover
7878

@@ -175,27 +175,38 @@ def postproc_func(self: ModelABC, func: Callable) -> None:
175175
else:
176176
self._postproc = func
177177

178-
def to(self: ModelABC, device: str = "cpu") -> torch.nn.Module:
178+
def to( # type: ignore[override]
179+
self: ModelABC,
180+
device: str = "cpu",
181+
dtype: torch.dtype | None = None,
182+
*,
183+
non_blocking: bool = False,
184+
) -> ModelABC | torch.nn.DataParallel[ModelABC]:
179185
"""Transfers model to cpu/gpu.
180186
181187
Args:
182188
model (torch.nn.Module):
183189
PyTorch defined model.
184190
device (str):
185191
Transfers model to the specified device. Default is "cpu".
192+
dtype (:class:`torch.dtype`): the desired floating point or complex dtype of
193+
the parameters and buffers in this module.
194+
non_blocking (bool): When set, it tries to convert/move asynchronously
195+
with respect to the host if possible, e.g., moving CPU Tensors with
196+
pinned memory to CUDA devices.
186197
187198
Returns:
188-
torch.nn.Module:
199+
torch.nn.Module | torch.nn.DataParallel:
189200
The model after being moved to cpu/gpu.
190201
191202
"""
192-
device = torch_device(device)
193-
model = super().to(device)
203+
torch_device = torch.device(device)
204+
model = super().to(torch_device, dtype=dtype, non_blocking=non_blocking)
194205

195206
# If target device istorch.cuda and more
196207
# than one GPU is available, use DataParallel
197-
if device.type == "cuda" and torch.cuda.device_count() > 1:
198-
model = torch.nn.DataParallel(model) # pragma: no cover
208+
if torch_device.type == "cuda" and torch.cuda.device_count() > 1:
209+
return torch.nn.DataParallel(model) # pragma: no cover
199210

200211
return model
201212

0 commit comments

Comments
 (0)