Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
55871dd
ENH: Use HF Hub for model weights and avoid imagenet weight downloading
mostafajahanifar Jul 11, 2025
b9dda0b
MAIN: no more download_data for model weights
mostafajahanifar Jul 11, 2025
c214608
MAINT: add huggingface_hub to requirements
mostafajahanifar Jul 11, 2025
ba874d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2025
54ded49
ENH: adding overwrite option to force download files
mostafajahanifar Jul 11, 2025
3081922
Merge branch 'enhance-hf-weights' of https://github.com/TissueImageAn…
mostafajahanifar Jul 11, 2025
53352b7
BUG: replace cache_dir wirh local_dir for specify the saving directory
mostafajahanifar Jul 14, 2025
59d6e93
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 14, 2025
abe6f3d
MAIN: pin the huggingface_hub version
mostafajahanifar Jul 14, 2025
9e18775
Merge branch 'enhance-hf-weights' of https://github.com/TissueImageAn…
mostafajahanifar Jul 14, 2025
ad222c9
BUG: fix type checking
mostafajahanifar Jul 18, 2025
0b52138
Merge branch 'develop' into enhance-hf-weights
shaneahmed Jul 25, 2025
41ba556
Merge branch 'develop' into enhance-hf-weights
shaneahmed Aug 5, 2025
6721567
Merge branch 'develop' into enhance-hf-weights
shaneahmed Aug 8, 2025
c53b7b2
Merge branch 'develop' into enhance-hf-weights
shaneahmed Sep 2, 2025
630ee1c
Merge branch 'develop' into enhance-hf-weights
shaneahmed Oct 2, 2025
fb1ed53
Merge branch 'develop' into enhance-hf-weights
shaneahmed Oct 10, 2025
caff2ea
address comments
Jiaqi-Lv Oct 10, 2025
30f3492
Merge branch 'develop' into enhance-hf-weights
shaneahmed Oct 16, 2025
7cd2357
update docstring
Jiaqi-Lv Oct 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ filelock>=3.9.0
flask>=2.2.2
flask-cors>=4.0.0
glymur>=0.12.7, < 0.14 # 0.14 is not compatible with python3.9
huggingface_hub # do we need to pin it to 0.33.3?
imagecodecs>=2022.9.26
joblib>=1.1.1
jupyterlab>=3.5.2
Expand Down
17 changes: 10 additions & 7 deletions tiatoolbox/models/architecture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from typing import TYPE_CHECKING

import torch
from huggingface_hub import hf_hub_download

from tiatoolbox import rcParam
from tiatoolbox.models.dataset.classification import predefined_preproc_func
from tiatoolbox.utils import download_data

if TYPE_CHECKING: # pragma: no cover
from tiatoolbox.models.models_abc import IOConfigABC
Expand Down Expand Up @@ -49,16 +49,19 @@ def fetch_pretrained_weights(

info = PRETRAINED_INFO[model_name]

file_name = info["url"].split("/")[-1]
if save_path is None:
file_name = info["url"].split("/")[-1]
processed_save_path = rcParam["TIATOOLBOX_HOME"] / "models" / file_name
cache_dir = rcParam["TIATOOLBOX_HOME"] / "models"
elif type(save_path) is str:
processed_save_path = Path(save_path)
cache_dir = Path(save_path)
else:
processed_save_path = save_path
cache_dir = save_path

download_data(info["url"], save_path=processed_save_path, overwrite=overwrite)
return processed_save_path
return hf_hub_download(
repo_id="TIACentre/TIAToolbox_pretrained_weights",
filename=file_name,
cache_dir=cache_dir,
)


def get_pretrained_model(
Expand Down
11 changes: 7 additions & 4 deletions tiatoolbox/models/architecture/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

def _get_architecture(
arch_name: str,
weights: str or WeightsEnum = "DEFAULT",
weights: str or WeightsEnum = None,
**kwargs: dict,
) -> list[nn.Sequential, ...] | nn.Sequential:
"""Retrieve a CNN model architecture.
Expand Down Expand Up @@ -75,17 +75,18 @@ def _get_architecture(
raise ValueError(msg)

creator = backbone_dict[arch_name]
model = creator(weights=weights, **kwargs)
if "inception_v3" in arch_name or "googlenet" in arch_name:
model = creator(weights=weights, aux_logits=False, num_classes=1000)
return nn.Sequential(*list(model.children())[:-3])

model = creator(weights=weights, **kwargs)
# Unroll all the definition and strip off the final GAP and FCN
if "resnet" in arch_name or "resnext" in arch_name:
return nn.Sequential(*list(model.children())[:-2])
if "densenet" in arch_name:
return model.features
if "alexnet" in arch_name:
return model.features
if "inception_v3" in arch_name or "googlenet" in arch_name:
return nn.Sequential(*list(model.children())[:-3])

return model.features

Expand Down Expand Up @@ -297,6 +298,7 @@ def __init__(self: CNNModel, backbone: str, num_classes: int = 1) -> None:
super().__init__()
self.num_classes = num_classes

# set num_classes to 100 to avoid downloading weights
self.feat_extract = _get_architecture(backbone)
self.pool = nn.AdaptiveAvgPool2d((1, 1))

Expand Down Expand Up @@ -547,6 +549,7 @@ class CNNBackbone(ModelABC):
def __init__(self: CNNBackbone, backbone: str) -> None:
"""Initialize :class:`CNNBackbone`."""
super().__init__()
# set num_classes=1000 to avoid downloading weights
self.feat_extract = _get_architecture(backbone)
self.pool = nn.AdaptiveAvgPool2d((1, 1))

Expand Down
Loading