17 | 17 | import numpy as np |
18 | 18 | from torchvision.models import WeightsEnum |
19 | 19 |
| 20 | +torch_cnn_backbone_dict = { |
| 21 | + "alexnet": torch_models.alexnet, |
| 22 | + "resnet18": torch_models.resnet18, |
| 23 | + "resnet34": torch_models.resnet34, |
| 24 | + "resnet50": torch_models.resnet50, |
| 25 | + "resnet101": torch_models.resnet101, |
| 26 | + "resnext50_32x4d": torch_models.resnext50_32x4d, |
| 27 | + "resnext101_32x8d": torch_models.resnext101_32x8d, |
| 28 | + "wide_resnet50_2": torch_models.wide_resnet50_2, |
| 29 | + "wide_resnet101_2": torch_models.wide_resnet101_2, |
| 30 | + "densenet121": torch_models.densenet121, |
| 31 | + "densenet161": torch_models.densenet161, |
| 32 | + "densenet169": torch_models.densenet169, |
| 33 | + "densenet201": torch_models.densenet201, |
| 34 | + "inception_v3": torch_models.inception_v3, |
| 35 | + "googlenet": torch_models.googlenet, |
| 36 | + "mobilenet_v2": torch_models.mobilenet_v2, |
| 37 | + "mobilenet_v3_large": torch_models.mobilenet_v3_large, |
| 38 | + "mobilenet_v3_small": torch_models.mobilenet_v3_small, |
| 39 | +} |
| 40 | + |
| 41 | +timm_arch_dict = { |
| 42 | + # UNI tile encoder: https://huggingface.co/MahmoodLab/UNI |
| 43 | + "UNI": { |
| 44 | + "model": "hf-hub:MahmoodLab/UNI", |
| 45 | + "init_values": 1e-5, |
| 46 | + "dynamic_img_size": True, |
| 47 | + }, |
| 48 | + # Prov-GigaPath tile encoder: https://huggingface.co/prov-gigapath/prov-gigapath |
| 49 | + "prov-gigapath": {"model": "hf_hub:prov-gigapath/prov-gigapath"}, |
| 50 | + # H-Optimus-0 tile encoder: https://huggingface.co/bioptimus/H-optimus-0 |
| 51 | + "H-optimus-0": { |
| 52 | + "model": "hf-hub:bioptimus/H-optimus-0", |
| 53 | + "init_values": 1e-5, |
| 54 | + "dynamic_img_size": False, |
| 55 | + }, |
| 56 | + # H-Optimus-1 tile encoder: https://huggingface.co/bioptimus/H-optimus-1 |
| 57 | + "H-optimus-1": { |
| 58 | + "model": "hf-hub:bioptimus/H-optimus-1", |
| 59 | + "init_values": 1e-5, |
| 60 | + "dynamic_img_size": False, |
| 61 | + }, |
| 62 | + # H0-mini tile encoder: https://huggingface.co/bioptimus/H0-mini |
| 63 | + "H0-mini": { |
| 64 | + "model": "hf-hub:bioptimus/H0-mini", |
| 65 | + "init_values": 1e-5, |
| 66 | + "dynamic_img_size": False, |
| 67 | + "mlp_layer": timm.layers.SwiGLUPacked, |
| 68 | + "act_layer": torch.nn.SiLU, |
| 69 | + }, |
| 70 | + # UNI2-h tile encoder: https://huggingface.co/MahmoodLab/UNI2-h |
| 71 | + "UNI2": { |
| 72 | + "model": "hf-hub:MahmoodLab/UNI2-h", |
| 73 | + "img_size": 224, |
| 74 | + "patch_size": 14, |
| 75 | + "depth": 24, |
| 76 | + "num_heads": 24, |
| 77 | + "init_values": 1e-5, |
| 78 | + "embed_dim": 1536, |
| 79 | + "mlp_ratio": 2.66667 * 2, |
| 80 | + "num_classes": 0, |
| 81 | + "no_embed_class": True, |
| 82 | + "mlp_layer": timm.layers.SwiGLUPacked, |
| 83 | + "act_layer": torch.nn.SiLU, |
| 84 | + "reg_tokens": 8, |
| 85 | + "dynamic_img_size": True, |
| 86 | + }, |
| 87 | + # Virchow tile encoder: https://huggingface.co/paige-ai/Virchow |
| 88 | + "Virchow": { |
| 89 | + "model": "hf_hub:paige-ai/Virchow", |
| 90 | + "mlp_layer": timm.layers.SwiGLUPacked, |
| 91 | + "act_layer": torch.nn.SiLU, |
| 92 | + }, |
| 93 | + # Virchow2 tile encoder: https://huggingface.co/paige-ai/Virchow2 |
| 94 | + "Virchow2": { |
| 95 | + "model": "hf_hub:paige-ai/Virchow2", |
| 96 | + "mlp_layer": timm.layers.SwiGLUPacked, |
| 97 | + "act_layer": torch.nn.SiLU, |
| 98 | + }, |
| 99 | + # Kaiko tile encoder: |
| 100 | + # https://huggingface.co/1aurent/vit_large_patch14_reg4_224.kaiko_ai_towards_large_pathology_fms |
| 101 | + "kaiko": { |
| 102 | + "model": ( |
| 103 | + "hf_hub:1aurent/" |
| 104 | + "vit_large_patch14_reg4_224.kaiko_ai_towards_large_pathology_fms" |
| 105 | + ), |
| 106 | + "dynamic_img_size": True, |
| 107 | + }, |
| 108 | +} |
| 109 | + |
20 | 110 |
21 | 111 | def _get_architecture( |
22 | 112 | arch_name: str, |
@@ -52,31 +142,11 @@ def _get_architecture( |
52 | 142 | >>> print(model) |
53 | 143 |
54 | 144 | """ |
55 | | - backbone_dict = { |
56 | | - "alexnet": torch_models.alexnet, |
57 | | - "resnet18": torch_models.resnet18, |
58 | | - "resnet34": torch_models.resnet34, |
59 | | - "resnet50": torch_models.resnet50, |
60 | | - "resnet101": torch_models.resnet101, |
61 | | - "resnext50_32x4d": torch_models.resnext50_32x4d, |
62 | | - "resnext101_32x8d": torch_models.resnext101_32x8d, |
63 | | - "wide_resnet50_2": torch_models.wide_resnet50_2, |
64 | | - "wide_resnet101_2": torch_models.wide_resnet101_2, |
65 | | - "densenet121": torch_models.densenet121, |
66 | | - "densenet161": torch_models.densenet161, |
67 | | - "densenet169": torch_models.densenet169, |
68 | | - "densenet201": torch_models.densenet201, |
69 | | - "inception_v3": torch_models.inception_v3, |
70 | | - "googlenet": torch_models.googlenet, |
71 | | - "mobilenet_v2": torch_models.mobilenet_v2, |
72 | | - "mobilenet_v3_large": torch_models.mobilenet_v3_large, |
73 | | - "mobilenet_v3_small": torch_models.mobilenet_v3_small, |
74 | | - } |
75 | | - if arch_name not in backbone_dict: |
| 145 | + if arch_name not in torch_cnn_backbone_dict: |
76 | 146 | msg = f"Backbone `{arch_name}` is not supported." |
77 | 147 | raise ValueError(msg) |
78 | 148 |
79 | | - creator = backbone_dict[arch_name] |
| 149 | + creator = torch_cnn_backbone_dict[arch_name] |
80 | 150 | if "inception_v3" in arch_name or "googlenet" in arch_name: |
81 | 151 | model = creator(weights=weights, aux_logits=False, num_classes=1000) |
82 | 152 | return nn.Sequential(*list(model.children())[:-3]) |
@@ -123,87 +193,18 @@ def _get_timm_architecture( |
123 | 193 | >>> print(model) |
124 | 194 |
125 | 195 | """ |
126 | | - if arch_name in [f"efficientnet_b{i}" for i in range(8)]: |
127 | | - model = timm.create_model(arch_name, pretrained=pretrained) |
128 | | - return nn.Sequential(*list(model.children())[:-1]) |
129 | | - |
130 | | - arch_map = { |
131 | | - # UNI tile encoder: https://huggingface.co/MahmoodLab/UNI |
132 | | - "UNI": { |
133 | | - "model": "hf-hub:MahmoodLab/UNI", |
134 | | - "init_values": 1e-5, |
135 | | - "dynamic_img_size": True, |
136 | | - }, |
137 | | - # Prov-GigaPath tile encoder: https://huggingface.co/prov-gigapath/prov-gigapath |
138 | | - "prov-gigapath": {"model": "hf_hub:prov-gigapath/prov-gigapath"}, |
139 | | - # H-Optimus-0 tile encoder: https://huggingface.co/bioptimus/H-optimus-0 |
140 | | - "H-optimus-0": { |
141 | | - "model": "hf-hub:bioptimus/H-optimus-0", |
142 | | - "init_values": 1e-5, |
143 | | - "dynamic_img_size": False, |
144 | | - }, |
145 | | - # H-Optimus-1 tile encoder: https://huggingface.co/bioptimus/H-optimus-1 |
146 | | - "H-optimus-1": { |
147 | | - "model": "hf-hub:bioptimus/H-optimus-1", |
148 | | - "init_values": 1e-5, |
149 | | - "dynamic_img_size": False, |
150 | | - }, |
151 | | - # HO-mini tile encoder: https://huggingface.co/bioptimus/H0-mini |
152 | | - "H0-mini": { |
153 | | - "model": "hf-hub:bioptimus/H0-mini", |
154 | | - "init_values": 1e-5, |
155 | | - "dynamic_img_size": False, |
156 | | - "mlp_layer": timm.layers.SwiGLUPacked, |
157 | | - "act_layer": torch.nn.SiLU, |
158 | | - }, |
159 | | - # UNI2-h tile encoder: https://huggingface.co/MahmoodLab/UNI2-h |
160 | | - "UNI2": { |
161 | | - "model": "hf-hub:MahmoodLab/UNI2-h", |
162 | | - "img_size": 224, |
163 | | - "patch_size": 14, |
164 | | - "depth": 24, |
165 | | - "num_heads": 24, |
166 | | - "init_values": 1e-5, |
167 | | - "embed_dim": 1536, |
168 | | - "mlp_ratio": 2.66667 * 2, |
169 | | - "num_classes": 0, |
170 | | - "no_embed_class": True, |
171 | | - "mlp_layer": timm.layers.SwiGLUPacked, |
172 | | - "act_layer": torch.nn.SiLU, |
173 | | - "reg_tokens": 8, |
174 | | - "dynamic_img_size": True, |
175 | | - }, |
176 | | - # Virchow tile encoder: https://huggingface.co/paige-ai/Virchow |
177 | | - "Virchow": { |
178 | | - "model": "hf_hub:paige-ai/Virchow", |
179 | | - "mlp_layer": SwiGLUPacked, |
180 | | - "act_layer": torch.nn.SiLU, |
181 | | - }, |
182 | | - # Virchow2 tile encoder: https://huggingface.co/paige-ai/Virchow2 |
183 | | - "Virchow2": { |
184 | | - "model": "hf_hub:paige-ai/Virchow2", |
185 | | - "mlp_layer": SwiGLUPacked, |
186 | | - "act_layer": torch.nn.SiLU, |
187 | | - }, |
188 | | - # Kaiko tile encoder: |
189 | | - # https://huggingface.co/1aurent/vit_large_patch14_reg4_224.kaiko_ai_towards_large_pathology_fms |
190 | | - "kaiko": { |
191 | | - "model": ( |
192 | | - "hf_hub:1aurent/" |
193 | | - "vit_large_patch14_reg4_224.kaiko_ai_towards_large_pathology_fms" |
194 | | - ), |
195 | | - "dynamic_img_size": True, |
196 | | - }, |
197 | | - } |
198 | | - |
199 | | - if arch_name in arch_map: # pragma: no cover |
| 196 | + if arch_name in timm_arch_dict: # pragma: no cover |
200 | 197 | # Coverage skipped: timm API is tested using efficient U-Net. |
201 | 198 | return timm.create_model( |
202 | | - arch_map[arch_name].pop("model"), |
| 199 | + timm_arch_dict[arch_name]["model"], |
203 | 200 | pretrained=pretrained, |
204 | | - **arch_map[arch_name], |
| 201 | + **{k: v for k, v in timm_arch_dict[arch_name].items() if k != "model"}, |
205 | 202 | ) |
206 | 203 |
| 204 | + if arch_name in timm.list_models(): |
| 205 | + model = timm.create_model(arch_name, pretrained=pretrained) |
| 206 | + return nn.Sequential(*list(model.children())[:-1]) |
| 207 | + |
207 | 208 | msg = f"Backbone `{arch_name}` is not supported." |
208 | 209 | raise ValueError(msg) |
209 | 210 |