Skip to content

Commit 53352b7

Browse files
BUG: replace cache_dir wirh local_dir for specify the saving directory
1 parent 3081922 commit 53352b7

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

tests/test_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,21 +1618,25 @@ def test_from_multi_head_dat_type_dict(tmp_path: Path) -> None:
16181618

16191619
def test_fetch_pretrained_weights(tmp_path: Path) -> None:
16201620
"""Test fetching pretrained weights for a model."""
1621-
file_path = tmp_path / "test_fetch_pretrained_weights.pth"
1621+
model_name = "mobilenet_v3_small-pcam"
1622+
file_path = tmp_path / f"{model_name}.pth"
16221623
if file_path.exists():
16231624
file_path.unlink()
16241625

1625-
fetch_pretrained_weights(model_name="mobilenet_v3_small-pcam", save_path=file_path)
1626+
print("tmp_path: ", tmp_path)
1627+
mosi = fetch_pretrained_weights(model_name="mobilenet_v3_small-pcam", save_path=tmp_path)
1628+
print("returned path: ", mosi)
1629+
print("desired path: ", file_path)
1630+
print(file_path.exists())
16261631
assert file_path.exists()
16271632
assert file_path.stat().st_size > 0
16281633
file_path.unlink()
16291634

16301635
with pytest.raises(ValueError, match="does not exist"):
1631-
fetch_pretrained_weights("abc", file_path)
1636+
fetch_pretrained_weights("abc", tmp_path)
16321637

16331638
# Test save_path is str
1634-
file_path_str = str(file_path)
1635-
file_path = fetch_pretrained_weights("mobilenet_v3_small-pcam", file_path_str)
1639+
file_path = fetch_pretrained_weights(model_name, str(tmp_path))
16361640
assert Path(file_path).exists()
16371641
assert Path(file_path).stat().st_size > 0
16381642

tiatoolbox/models/architecture/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ def fetch_pretrained_weights(
3333
Refer to `::py::meth:get_pretrained_model` for all supported
3434
model names.
3535
save_path (str | Path):
36-
Path to save the weight of the
37-
corresponding `model_name`.
36+
Path to the directory in which the pretrained weight will be cached.
3837
overwrite (bool):
3938
Overwrite existing downloaded weights (force downloading).
4039
@@ -60,7 +59,7 @@ def fetch_pretrained_weights(
6059
return hf_hub_download(
6160
repo_id="TIACentre/TIAToolbox_pretrained_weights",
6261
filename=file_name,
63-
cache_dir=cache_dir,
62+
local_dir=cache_dir,
6463
force_download=overwrite,
6564
)
6665

tiatoolbox/models/architecture/vanilla.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def _get_architecture(
3333
Name of the architecture (e.g. 'resnet50', 'alexnet').
3434
weights (str or WeightsEnum):
3535
Pretrained torchvision model weights to use (get_model_weights).
36-
Defaults to "DEFAULT".
36+
Defaults is None to avoid downloading ImageNet weights.
37+
To initiate the models with ImageNet weights, use "DEFAULT".
3738
**kwargs (dict):
3839
Key-word arguments.
3940

0 commit comments

Comments
 (0)