Skip to content

Commit 2b584e0

Browse files
authored
Update vanilla.py
Add extra models including Virchow, Virchow2
1 parent 7021108 commit 2b584e0

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tiatoolbox/models/architecture/vanilla.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
import torchvision.models as torch_models
1111
from torch import nn
12+
from timm.layers import SwiGLUPacked
1213

1314
from tiatoolbox.models.models_abc import ModelABC
1415

@@ -156,6 +157,26 @@ def _get_timm_architecture(
156157
**timm_kwargs,
157158
)
158159

160+
if arch_name == "Virchow": # pragma: no cover
161+
# Virchow tile encoder: https://huggingface.co/paige-ai/Virchow
162+
# Coverage skipped timm API is tested using efficient U-Net.
163+
return timm.create_model(
164+
"hf_hub:paige-ai/Virchow",
165+
pretrained=pretrained,
166+
mlp_layer=SwiGLUPacked,
167+
act_layer=torch.nn.SiLU
168+
)
169+
170+
if arch_name == "Virchow2": # pragma: no cover
171+
# Virchow2 tile encoder: https://huggingface.co/paige-ai/Virchow2
172+
# Coverage skipped timm API is tested using efficient U-Net.
173+
return timm.create_model(
174+
"hf_hub:paige-ai/Virchow2",
175+
pretrained=pretrained,
176+
mlp_layer=SwiGLUPacked,
177+
act_layer=torch.nn.SiLU
178+
)
179+
159180
msg = f"Backbone {arch_name} not supported. "
160181
raise ValueError(msg)
161182

@@ -297,6 +318,8 @@ class TimmModel(ModelABC):
297318
- "UNI"
298319
- "prov-gigapath"
299320
- "UNI2"
321+
- "Virchow"
322+
- "Virchow2"
300323
num_classes (int):
301324
Number of classes output by model.
302325
pretrained (bool, keyword-only):
@@ -482,6 +505,8 @@ class TimmBackbone(ModelABC):
482505
- "UNI"
483506
- "prov-gigapath"
484507
- "UNI2"
508+
- "Virchow"
509+
- "Virchow2"
485510
pretrained (bool, keyword-only):
486511
Whether to load pretrained weights.
487512

0 commit comments

Comments
 (0)