Skip to content

Commit 5719fbc

Browse files
committed
🔀 Merge develop into dev-define-engines-abc
2 parents d6fd5fc + a9e34c3 commit 5719fbc

File tree

5 files changed

+66
-28
lines changed

5 files changed

+66
-28
lines changed

.github/workflows/mypy-type-check.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,8 @@ jobs:
4646
tiatoolbox/tools \
4747
tiatoolbox/data \
4848
tiatoolbox/annotation \
49-
tiatoolbox/cli/common.py
49+
tiatoolbox/cli/common.py \
50+
tiatoolbox/models/__init__.py \
51+
tiatoolbox/models/models_abc.py \
52+
tiatoolbox/models/architecture/__init__.py \
53+
tiatoolbox/models/architecture/utils.py \

tests/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,10 +1614,17 @@ def test_fetch_pretrained_weights(tmp_path: Path) -> None:
16141614
fetch_pretrained_weights(model_name="mobilenet_v3_small-pcam", save_path=file_path)
16151615
assert file_path.exists()
16161616
assert file_path.stat().st_size > 0
1617+
file_path.unlink()
16171618

16181619
with pytest.raises(ValueError, match="does not exist"):
16191620
fetch_pretrained_weights("abc", file_path)
16201621

1622+
# Test save_path is str
1623+
file_path_str = str(file_path)
1624+
file_path = fetch_pretrained_weights("mobilenet_v3_small-pcam", file_path_str)
1625+
assert Path(file_path).exists()
1626+
assert Path(file_path).stat().st_size > 0
1627+
16211628

16221629
def test_imwrite(tmp_path: Path) -> NoReturn:
16231630
"""Create a temporary file path."""

tiatoolbox/models/architecture/__init__.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from pathlib import Path
56
from pydoc import locate
67
from typing import TYPE_CHECKING
78

@@ -11,11 +12,10 @@
1112
from tiatoolbox.utils import download_data
1213

1314
if TYPE_CHECKING: # pragma: no cover
14-
from pathlib import Path
15-
1615
import torch
1716

1817
from tiatoolbox.models.engine.io_config import ModelIOConfigABC
18+
from tiatoolbox.models.models_abc import IOConfigABC
1919

2020
__all__ = ["fetch_pretrained_weights", "get_pretrained_model"]
2121
PRETRAINED_INFO = rcParam["pretrained_model_info"]
@@ -52,10 +52,14 @@ def fetch_pretrained_weights(
5252

5353
if save_path is None:
5454
file_name = info["url"].split("/")[-1]
55-
save_path = rcParam["TIATOOLBOX_HOME"] / "models" / file_name
55+
processed_save_path = rcParam["TIATOOLBOX_HOME"] / "models" / file_name
56+
elif type(save_path) is str:
57+
processed_save_path = Path(save_path)
58+
else:
59+
processed_save_path = save_path
5660

57-
download_data(info["url"], save_path=save_path, overwrite=overwrite)
58-
return save_path
61+
download_data(info["url"], save_path=processed_save_path, overwrite=overwrite)
62+
return processed_save_path
5963

6064

6165
def get_pretrained_model(
@@ -128,9 +132,15 @@ def get_pretrained_model(
128132
info = PRETRAINED_INFO[pretrained_model]
129133

130134
arch_info = info["architecture"]
131-
creator = locate(f"tiatoolbox.models.architecture.{arch_info['class']}")
132-
133-
model = creator(**arch_info["kwargs"])
135+
model_class_info = arch_info["class"]
136+
model_module_name = str(".".join(model_class_info.split(".")[:-1]))
137+
model_name = str(model_class_info.split(".")[-1])
138+
139+
# Import module containing required model class
140+
arch_module = locate(f"tiatoolbox.models.architecture.{model_module_name}")
141+
# Get model class form module
142+
model_class = getattr(arch_module, model_name)
143+
model = model_class(**arch_info["kwargs"])
134144
# TODO(TBC): Dictionary of dataset specific or transformation? # noqa: FIX002,TD003
135145
if "dataset" in info:
136146
# ! this is a hack currently, need another PR to clean up
@@ -148,7 +158,12 @@ def get_pretrained_model(
148158
# !
149159

150160
io_info = info["ioconfig"]
151-
creator = locate(f"tiatoolbox.models.engine.{io_info['class']}")
161+
io_class_info = io_info["class"]
162+
io_module_name = str(".".join(io_class_info.split(".")[:-1]))
163+
io_class_name = str(io_class_info.split(".")[-1])
164+
165+
engine_module = locate(f"tiatoolbox.models.engine.{io_module_name}")
166+
engine_class = getattr(engine_module, io_class_name)
152167

153-
ioconfig = creator(**io_info["kwargs"])
168+
ioconfig = engine_class(**io_info["kwargs"])
154169
return model, ioconfig

tiatoolbox/models/architecture/utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import sys
6-
from typing import TYPE_CHECKING
6+
from typing import TYPE_CHECKING, cast
77

88
import numpy as np
99
import torch
@@ -101,12 +101,12 @@ def compile_model(
101101
)
102102
return model
103103

104-
return torch.compile(model, mode=mode) # pragma: no cover
104+
return cast(nn.Module, torch.compile(model, mode=mode)) # pragma: no cover
105105

106106

107107
def centre_crop(
108-
img: np.ndarray | torch.tensor,
109-
crop_shape: np.ndarray | torch.tensor,
108+
img: np.ndarray | torch.Tensor,
109+
crop_shape: np.ndarray | torch.Tensor | tuple,
110110
data_format: str = "NCHW",
111111
) -> np.ndarray | torch.Tensor:
112112
"""A function to center crop image with given crop shape.
@@ -140,8 +140,8 @@ def centre_crop(
140140

141141

142142
def centre_crop_to_shape(
143-
x: np.ndarray | torch.tensor,
144-
y: np.ndarray | torch.tensor,
143+
x: np.ndarray | torch.Tensor,
144+
y: np.ndarray | torch.Tensor,
145145
data_format: str = "NCHW",
146146
) -> np.ndarray | torch.Tensor:
147147
"""A function to center crop image to shape.
@@ -204,6 +204,7 @@ def __init__(self: UpSample2x) -> None:
204204
"""Initialize :class:`UpSample2x`."""
205205
super().__init__()
206206
# correct way to create constant within module
207+
self.unpool_mat: torch.Tensor
207208
self.register_buffer(
208209
"unpool_mat",
209210
torch.from_numpy(np.ones((2, 2), dtype="float32")),

tiatoolbox/models/models_abc.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,10 @@
77

88
import torch
99
import torch._dynamo
10-
from torch import device as torch_device
1110
from torch import nn
1211

1312
torch._dynamo.config.suppress_errors = True # skipcq: PYL-W0212 # noqa: SLF001
1413

15-
1614
if TYPE_CHECKING: # pragma: no cover
1715
from pathlib import Path
1816

@@ -58,8 +56,8 @@ def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module:
5856
# DataParallel work only for cuda
5957
model = torch.nn.DataParallel(model)
6058

61-
device = torch.device(device)
62-
return model.to(device)
59+
torch_device = torch.device(device)
60+
return model.to(torch_device)
6361

6462

6563
class ModelABC(ABC, torch.nn.Module):
@@ -73,7 +71,9 @@ def __init__(self: ModelABC) -> None:
7371

7472
@abstractmethod
7573
# This is generic abc, else pylint will complain
76-
def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None:
74+
def forward(
75+
self: ModelABC, *args: tuple[Any, ...], **kwargs: dict
76+
) -> None | torch.Tensor:
7777
"""Torch method, this contains logic for using layers defined in init."""
7878
... # pragma: no cover
7979

@@ -172,27 +172,38 @@ def postproc_func(self: ModelABC, func: Callable) -> None:
172172
else:
173173
self._postproc = func
174174

175-
def to(self: ModelABC, device: str = "cpu") -> torch.nn.Module:
175+
def to( # type: ignore[override]
176+
self: ModelABC,
177+
device: str = "cpu",
178+
dtype: torch.dtype | None = None,
179+
*,
180+
non_blocking: bool = False,
181+
) -> ModelABC | torch.nn.DataParallel[ModelABC]:
176182
"""Transfers model to cpu/gpu.
177183
178184
Args:
179185
self (ModelABC):
180186
PyTorch defined model.
181187
device (str):
182188
Transfers model to the specified device. Default is "cpu".
189+
dtype (:class:`torch.dtype`): the desired floating point or complex dtype of
190+
the parameters and buffers in this module.
191+
non_blocking (bool): When set, it tries to convert/move asynchronously
192+
with respect to the host if possible, e.g., moving CPU Tensors with
193+
pinned memory to CUDA devices.
183194
184195
Returns:
185-
torch.nn.Module:
196+
torch.nn.Module | torch.nn.DataParallel:
186197
The model after being moved to cpu/gpu.
187198
188199
"""
189-
device = torch_device(device)
190-
model = super().to(device)
200+
torch_device = torch.device(device)
201+
model = super().to(torch_device, dtype=dtype, non_blocking=non_blocking)
191202

192203
# If target device istorch.cuda and more
193204
# than one GPU is available, use DataParallel
194-
if device.type == "cuda" and torch.cuda.device_count() > 1:
195-
model = torch.nn.DataParallel(model) # pragma: no cover
205+
if torch_device.type == "cuda" and torch.cuda.device_count() > 1:
206+
return torch.nn.DataParallel(model) # pragma: no cover
196207

197208
return model
198209

0 commit comments

Comments
 (0)