|
9 | 9 | import torch |
10 | 10 | import torchvision.models as torch_models |
11 | 11 | from torch import nn |
| 12 | +from timm.layers import SwiGLUPacked |
12 | 13 |
|
13 | 14 | from tiatoolbox.models.models_abc import ModelABC |
14 | 15 |
|
@@ -156,6 +157,26 @@ def _get_timm_architecture( |
156 | 157 | **timm_kwargs, |
157 | 158 | ) |
158 | 159 |
|
| 160 | + if arch_name == "Virchow": # pragma: no cover |
| 161 | + # Virchow tile encoder: https://huggingface.co/paige-ai/Virchow |
| 162 | + # Coverage skipped timm API is tested using efficient U-Net. |
| 163 | + return timm.create_model( |
| 164 | + "hf_hub:paige-ai/Virchow", |
| 165 | + pretrained=pretrained, |
| 166 | + mlp_layer=SwiGLUPacked, |
| 167 | + act_layer=torch.nn.SiLU |
| 168 | + ) |
| 169 | + |
| 170 | + if arch_name == "Virchow2": # pragma: no cover |
| 171 | + # Virchow2 tile encoder: https://huggingface.co/paige-ai/Virchow2 |
| 172 | + # Coverage skipped timm API is tested using efficient U-Net. |
| 173 | + return timm.create_model( |
| 174 | + "hf_hub:paige-ai/Virchow2", |
| 175 | + pretrained=pretrained, |
| 176 | + mlp_layer=SwiGLUPacked, |
| 177 | + act_layer=torch.nn.SiLU |
| 178 | + ) |
| 179 | + |
159 | 180 | msg = f"Backbone {arch_name} not supported. " |
160 | 181 | raise ValueError(msg) |
161 | 182 |
|
@@ -297,6 +318,8 @@ class TimmModel(ModelABC): |
297 | 318 | - "UNI" |
298 | 319 | - "prov-gigapath" |
299 | 320 | - "UNI2" |
| 321 | + - "Virchow" |
| 322 | + - "Virchow2" |
300 | 323 | num_classes (int): |
301 | 324 | Number of classes output by model. |
302 | 325 | pretrained (bool, keyword-only): |
@@ -482,6 +505,8 @@ class TimmBackbone(ModelABC): |
482 | 505 | - "UNI" |
483 | 506 | - "prov-gigapath" |
484 | 507 | - "UNI2" |
| 508 | + - "Virchow" |
| 509 | + - "Virchow2" |
485 | 510 | pretrained (bool, keyword-only): |
486 | 511 | Whether to load pretrained weights. |
487 | 512 |
|
|
0 commit comments