33from __future__ import annotations
44
55import os
6+ from pathlib import Path
67from pydoc import locate
7- from typing import TYPE_CHECKING , Optional , Union
8+ from typing import TYPE_CHECKING , Optional , Union , cast
89
910import torch
1011
1112from tiatoolbox import rcParam
1213from tiatoolbox .models .dataset .classification import predefined_preproc_func
14+ from tiatoolbox .models .models_abc import ModelABC
1315from tiatoolbox .utils import download_data
1416
1517if TYPE_CHECKING : # pragma: no cover
16- from pathlib import Path
17-
1818 from tiatoolbox .models .models_abc import IOConfigABC
1919
2020
@@ -53,10 +53,13 @@ def fetch_pretrained_weights(
5353
5454 if save_path is None :
5555 file_name = info ["url" ].split ("/" )[- 1 ]
56- save_path = rcParam ["TIATOOLBOX_HOME" ] / "models" / file_name
56+ processed_save_path = rcParam ["TIATOOLBOX_HOME" ] / "models" / file_name
57+
58+ if type (save_path ) is str :
59+ processed_save_path = Path (save_path )
5760
58- download_data (info ["url" ], save_path = save_path , overwrite = overwrite )
59- return save_path
61+ download_data (info ["url" ], save_path = processed_save_path , overwrite = overwrite )
62+ return processed_save_path
6063
6164
6265def get_pretrained_model (
@@ -129,9 +132,15 @@ def get_pretrained_model(
129132 info = PRETRAINED_INFO [pretrained_model ]
130133
131134 arch_info = info ["architecture" ]
132- creator = locate (f"tiatoolbox.models.architecture.{ arch_info ['class' ]} " )
133-
134- model = creator (** arch_info ["kwargs" ])
135+ model_class_info = arch_info ["class" ]
136+ model_module_name = str ("." .join (model_class_info .split ("." )[:- 1 ]))
137+ model_name = str (model_class_info .split ("." )[- 1 ])
138+
139+ # Import module containing required model class
140+ arch_module = locate (f"tiatoolbox.models.architecture.{ model_module_name } " )
141+ # Get model class form module
142+ model_class = getattr (arch_module , model_name )
143+ model = model_class (** arch_info ["kwargs" ])
135144 # TODO(TBC): Dictionary of dataset specific or transformation? # noqa: FIX002,TD003
136145 if "dataset" in info :
137146 # ! this is a hack currently, need another PR to clean up
@@ -152,7 +161,12 @@ def get_pretrained_model(
152161 # !
153162
154163 io_info = info ["ioconfig" ]
155- creator = locate (f"tiatoolbox.models.engine.{ io_info ['class' ]} " )
164+ io_class_info = io_info ["class" ]
165+ io_module_name = str ("." .join (io_class_info .split ("." )[:- 1 ]))
166+ io_class_name = str (io_class_info .split ("." )[- 1 ])
167+
168+ engine_module = locate (f"tiatoolbox.models.engine.{ io_module_name } " )
169+ engine_class = getattr (engine_module , io_class_name )
156170
157- iostate = creator (** io_info ["kwargs" ])
171+ iostate = engine_class (** io_info ["kwargs" ])
158172 return model , iostate
0 commit comments