Skip to content

Commit d173a8c

Browse files
mostafajahanifarJiaqi-Lvshaneahmed
authored
🤗 HuggingFace for pretrained model weights hosting (#945)
This PR migrates model weight hosting from TIA servers to HuggingFace for improved distribution and caching. The changes introduce HuggingFace Hub integration while modifying default behavior to avoid automatic ImageNet weight downloads. Key changes: - Replace custom download logic with [HuggingFace](https://huggingface.co/TIACentre/TIAToolbox_pretrained_weights) for pretrained model weights - Update model architecture handling for Inception and GoogleNet models - Use `huggingface_hub` for downloading/caching pretrained weights - Use `None` model `pretrained` argument during model initialization, to avoid redundant downloading of ImageNet weights --------- Co-authored-by: Jiaqi Lv <[email protected]> Co-authored-by: Shan E Ahmed Raza <[email protected]>
1 parent e1eb5bd commit d173a8c

File tree

5 files changed

+90
-80
lines changed

5 files changed

+90
-80
lines changed

requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ filelock>=3.9.0
99
flask>=2.2.2
1010
flask-cors>=4.0.0
1111
glymur>=0.12.7
12+
huggingface_hub>=0.33.3
1213
imagecodecs>=2022.9.26
1314
joblib>=1.1.1
1415
jupyterlab>=3.5.2

tests/test_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,21 +1618,24 @@ def test_from_multi_head_dat_type_dict(tmp_path: Path) -> None:
16181618

16191619
def test_fetch_pretrained_weights(tmp_path: Path) -> None:
16201620
"""Test fetching pretrained weights for a model."""
1621-
file_path = tmp_path / "test_fetch_pretrained_weights.pth"
1621+
model_name = "mobilenet_v3_small-pcam"
1622+
file_path = tmp_path / f"{model_name}.pth"
16221623
if file_path.exists():
16231624
file_path.unlink()
16241625

1625-
fetch_pretrained_weights(model_name="mobilenet_v3_small-pcam", save_path=file_path)
1626+
_ = fetch_pretrained_weights(
1627+
model_name="mobilenet_v3_small-pcam", save_path=tmp_path
1628+
)
1629+
16261630
assert file_path.exists()
16271631
assert file_path.stat().st_size > 0
16281632
file_path.unlink()
16291633

16301634
with pytest.raises(ValueError, match="does not exist"):
1631-
fetch_pretrained_weights("abc", file_path)
1635+
fetch_pretrained_weights("abc", tmp_path)
16321636

16331637
# Test save_path is str
1634-
file_path_str = str(file_path)
1635-
file_path = fetch_pretrained_weights("mobilenet_v3_small-pcam", file_path_str)
1638+
file_path = fetch_pretrained_weights(model_name, str(tmp_path))
16361639
assert Path(file_path).exists()
16371640
assert Path(file_path).stat().st_size > 0
16381641

0 commit comments

Comments
 (0)