
Commit d24dfbe

✨ Add support for architectures in timm.list_models()
1 parent e015933 · commit d24dfbe

File tree

5 files changed: +120 −125 lines

tests/engines/test_feature_extractor.py
tiatoolbox/cli/deep_feature_extractor.py
tiatoolbox/models/architecture/__init__.py
tiatoolbox/models/architecture/vanilla.py
tiatoolbox/models/engine/deep_feature_extractor.py
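
At a glance, the change lets callers pass any architecture name known to timm.list_models(), or one of the torchvision backbones in the new torch_cnn_backbone_dict, straight to the feature-extraction engine as a plain string. A minimal sketch of that usage, mirroring the updated docstring further down (the slide path is a placeholder):

    from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor

    # Pass a timm architecture name directly instead of a predefined alias;
    # "slide1.svs" is a placeholder path.
    extractor = DeepFeatureExtractor(model="efficientnet_b0")
    output = extractor.run(["slide1.svs"], patch_mode=False, output_type="zarr")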

tests/engines/test_feature_extractor.py

Lines changed: 7 additions & 23 deletions

@@ -8,9 +8,7 @@
 import pytest
 import torch
 import zarr
-from click.testing import CliRunner
 
-from tiatoolbox import cli
 from tiatoolbox.models import IOSegmentorConfig
 from tiatoolbox.models.architecture.vanilla import CNNBackbone, TimmBackbone
 from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor

@@ -93,7 +91,13 @@ def test_feature_extractor_wsi(remote_sample: Callable, track_tmp_path: Path) ->
 
 
 @pytest.mark.parametrize(
-    "model", [CNNBackbone("resnet50"), TimmBackbone("efficientnet_b0", pretrained=True)]
+    "model",
+    [
+        CNNBackbone("resnet18"),
+        TimmBackbone("efficientnet_b0", pretrained=True),
+        "resnet18",
+        "efficientnet_b0",
+    ],
 )
 def test_full_inference(
     remote_sample: Callable, track_tmp_path: Path, model: Callable

@@ -210,23 +214,3 @@ def test_multi_gpu_feature_extraction(
 # -------------------------------------------------------------------------------------
 # Command Line Interface
 # -------------------------------------------------------------------------------------
-
-
-def test_cli_model_single_file(sample_svs: Path, track_tmp_path: Path) -> None:
-    """Test for feature extractor CLI single file."""
-    runner = CliRunner()
-    models_wsi_result = runner.invoke(
-        cli.main,
-        [
-            "deep-feature-extractor",
-            "--img-input",
-            str(sample_svs),
-            "--patch-mode",
-            "False",
-            "--output-path",
-            str(track_tmp_path / "output"),
-        ],
-    )
-
-    assert models_wsi_result.exit_code == 0
-    assert (track_tmp_path / "output" / (sample_svs.stem + ".zarr")).exists()
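
As an aside on how the plain-string entries in the parametrization above are expected to resolve (the actual dispatch is in the architecture changes below), a small sketch assuming the new module-level table and timm's registry:

    import timm

    from tiatoolbox.models.architecture.vanilla import torch_cnn_backbone_dict

    # Assumed lookup order for a plain model name: torchvision table first,
    # then timm's model registry.
    assert "resnet18" in torch_cnn_backbone_dict
    assert "efficientnet_b0" in timm.list_models()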

tiatoolbox/cli/deep_feature_extractor.py

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@
 @cli_file_type(
     default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs",
 )
-@cli_model(default="fcn-tissue_mask")
+@cli_model(default="efficientnet_b0")
 @cli_weights()
 @cli_device(default="cpu")
 @cli_batch_size(default=1)
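
For illustration, a sketch that replays the removed CLI test with the new default model; the flags come from the deleted test above and the input path is a placeholder:

    from click.testing import CliRunner

    from tiatoolbox import cli

    # With the new default, this extracts features using efficientnet_b0.
    runner = CliRunner()
    result = runner.invoke(
        cli.main,
        [
            "deep-feature-extractor",
            "--img-input", "slide.svs",   # placeholder input
            "--patch-mode", "False",
            "--output-path", "output",
        ],
    )
    assert result.exit_code == 0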

tiatoolbox/models/architecture/__init__.py

Lines changed: 10 additions & 1 deletion

@@ -6,12 +6,15 @@
 from pydoc import locate
 from typing import TYPE_CHECKING
 
+import timm
 from huggingface_hub import hf_hub_download
 
 from tiatoolbox import rcParam
 from tiatoolbox.models.dataset.classification import predefined_preproc_func
 from tiatoolbox.models.models_abc import load_torch_model
 
+from .vanilla import CNNBackbone, TimmBackbone, timm_arch_dict, torch_cnn_backbone_dict
+
 if TYPE_CHECKING:  # pragma: no cover
     import torch
 

@@ -69,7 +72,7 @@ def get_pretrained_model(
     pretrained_weights: str | Path | None = None,
     *,
     overwrite: bool = False,
-) -> tuple[torch.nn.Module, ModelIOConfigABC]:
+) -> tuple[torch.nn.Module, ModelIOConfigABC | None]:
     """Load a predefined PyTorch model with the appropriate pretrained weights.
 
     Args:

@@ -127,6 +130,12 @@ def get_pretrained_model(
         msg = "pretrained_model must be a string."
         raise TypeError(msg)
 
+    if pretrained_model in torch_cnn_backbone_dict:
+        return CNNBackbone(pretrained_model), None
+
+    if pretrained_model in timm.list_models():
+        return TimmBackbone(pretrained_model, pretrained=True), None
+
     if pretrained_model not in PRETRAINED_INFO:
         msg = f"Pretrained model `{pretrained_model}` does not exist."
         raise ValueError(msg)
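
A short sketch of the new fall-through in get_pretrained_model: plain architecture names now return a backbone together with None in place of an IO config (the names mirror the updated tests):

    from tiatoolbox.models.architecture import get_pretrained_model

    # Torchvision names resolve to a CNNBackbone, timm names to a pretrained
    # TimmBackbone; in both cases no ModelIOConfigABC is returned.
    model, ioconfig = get_pretrained_model("resnet18")
    assert ioconfig is None
    model, ioconfig = get_pretrained_model("efficientnet_b0")
    assert ioconfig is None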

tiatoolbox/models/architecture/vanilla.py

Lines changed: 99 additions & 98 deletions

@@ -17,6 +17,96 @@
 import numpy as np
 from torchvision.models import WeightsEnum
 
+torch_cnn_backbone_dict = {
+    "alexnet": torch_models.alexnet,
+    "resnet18": torch_models.resnet18,
+    "resnet34": torch_models.resnet34,
+    "resnet50": torch_models.resnet50,
+    "resnet101": torch_models.resnet101,
+    "resnext50_32x4d": torch_models.resnext50_32x4d,
+    "resnext101_32x8d": torch_models.resnext101_32x8d,
+    "wide_resnet50_2": torch_models.wide_resnet50_2,
+    "wide_resnet101_2": torch_models.wide_resnet101_2,
+    "densenet121": torch_models.densenet121,
+    "densenet161": torch_models.densenet161,
+    "densenet169": torch_models.densenet169,
+    "densenet201": torch_models.densenet201,
+    "inception_v3": torch_models.inception_v3,
+    "googlenet": torch_models.googlenet,
+    "mobilenet_v2": torch_models.mobilenet_v2,
+    "mobilenet_v3_large": torch_models.mobilenet_v3_large,
+    "mobilenet_v3_small": torch_models.mobilenet_v3_small,
+}
+
+timm_arch_dict = {
+    # UNI tile encoder: https://huggingface.co/MahmoodLab/UNI
+    "UNI": {
+        "model": "hf-hub:MahmoodLab/UNI",
+        "init_values": 1e-5,
+        "dynamic_img_size": True,
+    },
+    # Prov-GigaPath tile encoder: https://huggingface.co/prov-gigapath/prov-gigapath
+    "prov-gigapath": {"model": "hf_hub:prov-gigapath/prov-gigapath"},
+    # H-Optimus-0 tile encoder: https://huggingface.co/bioptimus/H-optimus-0
+    "H-optimus-0": {
+        "model": "hf-hub:bioptimus/H-optimus-0",
+        "init_values": 1e-5,
+        "dynamic_img_size": False,
+    },
+    # H-Optimus-1 tile encoder: https://huggingface.co/bioptimus/H-optimus-1
+    "H-optimus-1": {
+        "model": "hf-hub:bioptimus/H-optimus-1",
+        "init_values": 1e-5,
+        "dynamic_img_size": False,
+    },
+    # HO-mini tile encoder: https://huggingface.co/bioptimus/H0-mini
+    "H0-mini": {
+        "model": "hf-hub:bioptimus/H0-mini",
+        "init_values": 1e-5,
+        "dynamic_img_size": False,
+        "mlp_layer": timm.layers.SwiGLUPacked,
+        "act_layer": torch.nn.SiLU,
+    },
+    # UNI2-h tile encoder: https://huggingface.co/MahmoodLab/UNI2-h
+    "UNI2": {
+        "model": "hf-hub:MahmoodLab/UNI2-h",
+        "img_size": 224,
+        "patch_size": 14,
+        "depth": 24,
+        "num_heads": 24,
+        "init_values": 1e-5,
+        "embed_dim": 1536,
+        "mlp_ratio": 2.66667 * 2,
+        "num_classes": 0,
+        "no_embed_class": True,
+        "mlp_layer": timm.layers.SwiGLUPacked,
+        "act_layer": torch.nn.SiLU,
+        "reg_tokens": 8,
+        "dynamic_img_size": True,
+    },
+    # Virchow tile encoder: https://huggingface.co/paige-ai/Virchow
+    "Virchow": {
+        "model": "hf_hub:paige-ai/Virchow",
+        "mlp_layer": SwiGLUPacked,
+        "act_layer": torch.nn.SiLU,
+    },
+    # Virchow2 tile encoder: https://huggingface.co/paige-ai/Virchow2
+    "Virchow2": {
+        "model": "hf_hub:paige-ai/Virchow2",
+        "mlp_layer": SwiGLUPacked,
+        "act_layer": torch.nn.SiLU,
+    },
+    # Kaiko tile encoder:
+    # https://huggingface.co/1aurent/vit_large_patch14_reg4_224.kaiko_ai_towards_large_pathology_fms
+    "kaiko": {
+        "model": (
+            "hf_hub:1aurent/"
+            "vit_large_patch14_reg4_224.kaiko_ai_towards_large_pathology_fms"
+        ),
+        "dynamic_img_size": True,
+    },
+}
+
 
 def _get_architecture(
     arch_name: str,

@@ -52,31 +142,11 @@ def _get_architecture(
     >>> print(model)
 
     """
-    backbone_dict = {
-        "alexnet": torch_models.alexnet,
-        "resnet18": torch_models.resnet18,
-        "resnet34": torch_models.resnet34,
-        "resnet50": torch_models.resnet50,
-        "resnet101": torch_models.resnet101,
-        "resnext50_32x4d": torch_models.resnext50_32x4d,
-        "resnext101_32x8d": torch_models.resnext101_32x8d,
-        "wide_resnet50_2": torch_models.wide_resnet50_2,
-        "wide_resnet101_2": torch_models.wide_resnet101_2,
-        "densenet121": torch_models.densenet121,
-        "densenet161": torch_models.densenet161,
-        "densenet169": torch_models.densenet169,
-        "densenet201": torch_models.densenet201,
-        "inception_v3": torch_models.inception_v3,
-        "googlenet": torch_models.googlenet,
-        "mobilenet_v2": torch_models.mobilenet_v2,
-        "mobilenet_v3_large": torch_models.mobilenet_v3_large,
-        "mobilenet_v3_small": torch_models.mobilenet_v3_small,
-    }
-    if arch_name not in backbone_dict:
+    if arch_name not in torch_cnn_backbone_dict:
         msg = f"Backbone `{arch_name}` is not supported."
         raise ValueError(msg)
 
-    creator = backbone_dict[arch_name]
+    creator = torch_cnn_backbone_dict[arch_name]
     if "inception_v3" in arch_name or "googlenet" in arch_name:
         model = creator(weights=weights, aux_logits=False, num_classes=1000)
         return nn.Sequential(*list(model.children())[:-3])

@@ -123,87 +193,18 @@ def _get_timm_architecture(
     >>> print(model)
 
     """
-    if arch_name in [f"efficientnet_b{i}" for i in range(8)]:
-        model = timm.create_model(arch_name, pretrained=pretrained)
-        return nn.Sequential(*list(model.children())[:-1])
-
-    arch_map = {
-        # UNI tile encoder: https://huggingface.co/MahmoodLab/UNI
-        "UNI": {
-            "model": "hf-hub:MahmoodLab/UNI",
-            "init_values": 1e-5,
-            "dynamic_img_size": True,
-        },
-        # Prov-GigaPath tile encoder: https://huggingface.co/prov-gigapath/prov-gigapath
-        "prov-gigapath": {"model": "hf_hub:prov-gigapath/prov-gigapath"},
-        # H-Optimus-0 tile encoder: https://huggingface.co/bioptimus/H-optimus-0
-        "H-optimus-0": {
-            "model": "hf-hub:bioptimus/H-optimus-0",
-            "init_values": 1e-5,
-            "dynamic_img_size": False,
-        },
-        # H-Optimus-1 tile encoder: https://huggingface.co/bioptimus/H-optimus-1
-        "H-optimus-1": {
-            "model": "hf-hub:bioptimus/H-optimus-1",
-            "init_values": 1e-5,
-            "dynamic_img_size": False,
-        },
-        # HO-mini tile encoder: https://huggingface.co/bioptimus/H0-mini
-        "H0-mini": {
-            "model": "hf-hub:bioptimus/H0-mini",
-            "init_values": 1e-5,
-            "dynamic_img_size": False,
-            "mlp_layer": timm.layers.SwiGLUPacked,
-            "act_layer": torch.nn.SiLU,
-        },
-        # UNI2-h tile encoder: https://huggingface.co/MahmoodLab/UNI2-h
-        "UNI2": {
-            "model": "hf-hub:MahmoodLab/UNI2-h",
-            "img_size": 224,
-            "patch_size": 14,
-            "depth": 24,
-            "num_heads": 24,
-            "init_values": 1e-5,
-            "embed_dim": 1536,
-            "mlp_ratio": 2.66667 * 2,
-            "num_classes": 0,
-            "no_embed_class": True,
-            "mlp_layer": timm.layers.SwiGLUPacked,
-            "act_layer": torch.nn.SiLU,
-            "reg_tokens": 8,
-            "dynamic_img_size": True,
-        },
-        # Virchow tile encoder: https://huggingface.co/paige-ai/Virchow
-        "Virchow": {
-            "model": "hf_hub:paige-ai/Virchow",
-            "mlp_layer": SwiGLUPacked,
-            "act_layer": torch.nn.SiLU,
-        },
-        # Virchow2 tile encoder: https://huggingface.co/paige-ai/Virchow2
-        "Virchow2": {
-            "model": "hf_hub:paige-ai/Virchow2",
-            "mlp_layer": SwiGLUPacked,
-            "act_layer": torch.nn.SiLU,
-        },
-        # Kaiko tile encoder:
-        # https://huggingface.co/1aurent/vit_large_patch14_reg4_224.kaiko_ai_towards_large_pathology_fms
-        "kaiko": {
-            "model": (
-                "hf_hub:1aurent/"
-                "vit_large_patch14_reg4_224.kaiko_ai_towards_large_pathology_fms"
-            ),
-            "dynamic_img_size": True,
-        },
-    }
-
-    if arch_name in arch_map:  # pragma: no cover
+    if arch_name in timm_arch_dict:  # pragma: no cover
         # Coverage skipped timm API is tested using efficient U-Net.
         return timm.create_model(
-            arch_map[arch_name].pop("model"),
+            timm_arch_dict[arch_name].pop("model"),
             pretrained=pretrained,
-            **arch_map[arch_name],
+            **timm_arch_dict[arch_name],
         )
 
+    if arch_name in timm.list_models():
+        model = timm.create_model(arch_name, pretrained=pretrained)
+        return nn.Sequential(*list(model.children())[:-1])
+
     msg = f"Backbone {arch_name} not supported. "
     raise ValueError(msg)

tiatoolbox/models/engine/deep_feature_extractor.py

Lines changed: 3 additions & 2 deletions

@@ -29,11 +29,11 @@
     Example:
     --------
     >>> from tiatoolbox.models.engine.deep_feature_extractor import DeepFeatureExtractor
-    >>> extractor = DeepFeatureExtractor(model="resnet50-kather100k")
+    >>> extractor = DeepFeatureExtractor(model="efficientnet_b0")
     >>> wsis = ["slide1.svs", "slide2.svs"]
     >>> output = extractor.run(wsis, patch_mode=False, output_type="zarr")
     >>> print(output)
-    '/path/to/output.zarr'
+    ... '/path/to/output.zarr'
 
     """

@@ -731,6 +731,7 @@ def run(
         Raises:
            ValueError:
                If `output_type` is not "zarr" or "dict".
+
        """
        # return_probabilities is always True for FeatureExtractor.
        kwargs["return_probabilities"] = True
