
Commit 35ffd61

fix typing
1 parent b580e1e commit 35ffd61

File tree

.github/workflows/mypy-type-check.yml
tiatoolbox/models/architecture/__init__.py
tiatoolbox/models/models_abc.py

3 files changed: +32 -16 lines changed

.github/workflows/mypy-type-check.yml

Lines changed: 4 additions & 1 deletion
@@ -46,4 +46,7 @@ jobs:
           tiatoolbox/tools \
           tiatoolbox/data \
           tiatoolbox/annotation \
-          tiatoolbox/cli/common.py
+          tiatoolbox/cli/common.py \
+          tiatoolbox/models/__init__.py \
+          tiatoolbox/models/models_abc.py \
+          tiatoolbox/models/architecture/__init__.py \
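The workflow change simply extends the list of paths mypy checks in CI. To reproduce the check locally before pushing, a minimal sketch using mypy's Python API (the file list mirrors the workflow above; running `mypy <paths>` from the shell is equivalent):

from mypy import api

# Paths newly covered by the CI type check (taken from the diff above).
files = [
    "tiatoolbox/cli/common.py",
    "tiatoolbox/models/__init__.py",
    "tiatoolbox/models/models_abc.py",
    "tiatoolbox/models/architecture/__init__.py",
]

# api.run returns (stdout, stderr, exit_status); exit status 0 means clean.
stdout, stderr, exit_status = api.run(files)
print(stdout or "no type errors")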

tiatoolbox/models/architecture/__init__.py

Lines changed: 25 additions & 11 deletions
@@ -3,18 +3,18 @@
 from __future__ import annotations
 
 import os
+from pathlib import Path
 from pydoc import locate
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional, Union, cast
 
 import torch
 
 from tiatoolbox import rcParam
 from tiatoolbox.models.dataset.classification import predefined_preproc_func
+from tiatoolbox.models.models_abc import ModelABC
 from tiatoolbox.utils import download_data
 
 if TYPE_CHECKING:  # pragma: no cover
-    from pathlib import Path
-
     from tiatoolbox.models.models_abc import IOConfigABC
@@ -53,10 +53,13 @@ def fetch_pretrained_weights(
 
     if save_path is None:
         file_name = info["url"].split("/")[-1]
-        save_path = rcParam["TIATOOLBOX_HOME"] / "models" / file_name
+        processed_save_path = rcParam["TIATOOLBOX_HOME"] / "models" / file_name
+
+    if type(save_path) is str:
+        processed_save_path = Path(save_path)
 
-    download_data(info["url"], save_path=save_path, overwrite=overwrite)
-    return save_path
+    download_data(info["url"], save_path=processed_save_path, overwrite=overwrite)
+    return processed_save_path
 
 
 def get_pretrained_model(
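The typing fix above follows a common mypy pattern: instead of reassigning `save_path` (typed `str | Path | None`) to a value of a different type, a separate `Path`-typed name (`processed_save_path`) is introduced so each variable keeps a single static type. A minimal standalone sketch of the same idea, using a hypothetical `resolve_save_path` helper and default directory that are not part of the commit:

from pathlib import Path

# Hypothetical default cache directory, standing in for rcParam["TIATOOLBOX_HOME"].
DEFAULT_MODELS_DIR = Path.home() / ".tiatoolbox" / "models"

def resolve_save_path(save_path: str | Path | None, file_name: str) -> Path:
    """Normalise a user-supplied save location into a concrete Path."""
    if save_path is None:
        return DEFAULT_MODELS_DIR / file_name  # fall back to the default location
    return Path(save_path)  # a str is promoted; an existing Path passes through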
@@ -129,9 +132,15 @@ def get_pretrained_model(
     info = PRETRAINED_INFO[pretrained_model]
 
     arch_info = info["architecture"]
-    creator = locate(f"tiatoolbox.models.architecture.{arch_info['class']}")
-
-    model = creator(**arch_info["kwargs"])
+    model_class_info = arch_info["class"]
+    model_module_name = str(".".join(model_class_info.split(".")[:-1]))
+    model_name = str(model_class_info.split(".")[-1])
+
+    # Import module containing required model class
+    arch_module = locate(f"tiatoolbox.models.architecture.{model_module_name}")
+    # Get model class from module
+    model_class = getattr(arch_module, model_name)
+    model = model_class(**arch_info["kwargs"])
     # TODO(TBC): Dictionary of dataset specific or transformation?  # noqa: FIX002,TD003
     if "dataset" in info:
         # ! this is a hack currently, need another PR to clean up
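Splitting the dotted path is what makes this hunk type-check: `pydoc.locate` is typed to return a plain object, which mypy will not let you call directly (the old `creator(...)`), whereas locating the module first and pulling the class out with `getattr` yields a value mypy allows to be called. A small sketch of the same two-step lookup, with a hypothetical `load_architecture_class` helper:

from pydoc import locate

def load_architecture_class(dotted_path: str) -> type:
    """Resolve 'module.ClassName' (e.g. 'vanilla.CNNModel') in two steps."""
    module_name, _, class_name = dotted_path.rpartition(".")
    # Step 1: import the module that defines the class.
    module = locate(f"tiatoolbox.models.architecture.{module_name}")
    if module is None:
        msg = f"Cannot locate module for {dotted_path!r}"
        raise ImportError(msg)
    # Step 2: fetch the class object off the imported module.
    return getattr(module, class_name)

The ioconfig hunk below applies the identical pattern to classes under tiatoolbox.models.engine.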
@@ -152,7 +161,12 @@ def get_pretrained_model(
     # !
 
     io_info = info["ioconfig"]
-    creator = locate(f"tiatoolbox.models.engine.{io_info['class']}")
+    io_class_info = io_info["class"]
+    io_module_name = str(".".join(io_class_info.split(".")[:-1]))
+    io_class_name = str(io_class_info.split(".")[-1])
+
+    engine_module = locate(f"tiatoolbox.models.engine.{io_module_name}")
+    engine_class = getattr(engine_module, io_class_name)
 
-    iostate = creator(**io_info["kwargs"])
+    iostate = engine_class(**io_info["kwargs"])
     return model, iostate

tiatoolbox/models/models_abc.py

Lines changed: 3 additions & 4 deletions
@@ -7,7 +7,6 @@
 
 import torch
 import torch._dynamo
-from torch import device as torch_device
 
 torch._dynamo.config.suppress_errors = True  # skipcq: PYL-W0212 # noqa: SLF001
@@ -189,12 +188,12 @@ def to(self: ModelABC, device: str = "cpu") -> torch.nn.Module:
189188
The model after being moved to cpu/gpu.
190189
191190
"""
192-
device = torch_device(device)
193-
model = super().to(device)
191+
torch_device = torch.device(device)
192+
model = super().to(torch_device)
194193

195194
# If target device istorch.cuda and more
196195
# than one GPU is available, use DataParallel
197-
if device.type == "cuda" and torch.cuda.device_count() > 1:
196+
if torch_device.type == "cuda" and torch.cuda.device_count() > 1:
198197
model = torch.nn.DataParallel(model) # pragma: no cover
199198

200199
return model
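The rename here resolves a classic mypy complaint: the old code rebound the `device: str` parameter to a `torch.device`, redefining a name with an incompatible type. Binding the converted value to a new local keeps both types stable. A self-contained sketch of the same pattern (`TinyModel` is a hypothetical stand-in for a ModelABC subclass, not code from the commit):

import torch

class TinyModel(torch.nn.Module):
    """Hypothetical model used only to illustrate the device-move pattern."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def to(self, device: str = "cpu") -> torch.nn.Module:
        # Keep the str parameter untouched; bind the torch.device to a new
        # name so each variable has exactly one static type.
        torch_device = torch.device(device)
        model = super().to(torch_device)
        # Wrap in DataParallel when several GPUs are available.
        if torch_device.type == "cuda" and torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        return model

model = TinyModel().to("cpu")  # moves (and possibly wraps) the module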
