9 | 9 | """ |
10 | 10 | from __future__ import annotations |
11 | 11 |
12 | | -from pathlib import Path |
13 | 12 | from typing import Literal |
14 | 13 |
15 | 14 | import timm |
16 | 15 | import torch |
17 | 16 | from torch import nn |
18 | 17 |
19 | | -from otx.algo.utils.mmengine_utils import load_checkpoint_to_model, load_from_http |
| 18 | +from otx.algo.utils.mmengine_utils import load_from_http |
20 | 19 |
21 | 20 | PRETRAINED_ROOT = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/" |
22 | 21 | pretrained_urls = { |
28 | 27 | "mobilenetv3_large_21k": "mobilenetv3_large_100_miil_in21k", |
29 | 28 | "mobilenetv3_large_1k": "mobilenetv3_large_100_miil", |
30 | 29 | "tresnet": "tresnet_m", |
31 | | - "efficientnetv2_s_21k": "tf_efficientnetv2_s_in21k", |
| 30 | + "efficientnetv2_s_21k": "tf_efficientnetv2_s.in21k", |
32 | 31 | "efficientnetv2_s_1k": "tf_efficientnetv2_s_in21ft1k", |
33 | 32 | "efficientnetv2_m_21k": "tf_efficientnetv2_m_in21k", |
34 | 33 | "efficientnetv2_m_1k": "tf_efficientnetv2_m_in21ft1k", |
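The renamed value follows timm's newer "model_name.pretrained_tag" convention. A quick sanity check, as a sketch only and assuming timm >= 0.9 is installed (where tagged names are returned when pretrained=True):

    import timm

    # Sketch: tagged pretrained names such as "tf_efficientnetv2_s.in21k" appear in
    # list_models(pretrained=True) on timm >= 0.9 (assumption about the installed version).
    print("tf_efficientnetv2_s.in21k" in timm.list_models(pretrained=True))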
@@ -59,12 +58,19 @@ def __init__( |
59 | 58 | ): |
60 | 59 | super().__init__(**kwargs) |
61 | 60 | self.backbone = backbone |
62 | | - self.pretrained = pretrained |
| 61 | + self.pretrained: bool | dict = pretrained |
63 | 62 | self.is_mobilenet = backbone.startswith("mobilenet") |
| 63 | + if pretrained and self.backbone in pretrained_urls: |
| 64 | + # Checkpoints fetched via load_from_http are cached in ~/.cache/torch/hub/checkpoints.
| 65 | + # Weights that timm downloads itself go to ~/.cache/huggingface/hub instead (timm default).
| 66 | + self.pretrained = load_from_http(filename=pretrained_urls[self.backbone]) |
| 67 | + |
| 68 | + self.model = timm.create_model( |
| 69 | + TIMM_MODEL_NAME_DICT[self.backbone], |
| 70 | + pretrained=self.pretrained, |
| 71 | + num_classes=1000, |
| 72 | + ) |
64 | 73 |
65 | | - self.model = timm.create_model(TIMM_MODEL_NAME_DICT[self.backbone], pretrained=pretrained, num_classes=1000) |
66 | | - if self.pretrained: |
67 | | - print(f"init weight - {pretrained_urls[self.backbone]}") |
68 | 74 | self.model.classifier = None # Detach classifier. Only use 'backbone' part in otx. |
69 | 75 | self.num_head_features = self.model.num_features |
70 | 76 | self.num_features = self.model.conv_head.in_channels if self.is_mobilenet else self.model.num_features |
@@ -97,15 +103,3 @@ def get_config_optim(self, lrs: list[float] | float) -> list[dict[str, float]]: |
97 | 103 | param_dict["lr"] = lrs |
98 | 104 |
99 | 105 | return parameters |
100 | | - |
101 | | - def init_weights(self, pretrained: str | bool | None = None) -> None: |
102 | | - """Initialize weights.""" |
103 | | - checkpoint = None |
104 | | - if isinstance(pretrained, str) and Path(pretrained).exists(): |
105 | | - checkpoint = torch.load(pretrained, None) |
106 | | - print(f"init weight - {pretrained}") |
107 | | - elif pretrained is not None: |
108 | | - checkpoint = load_from_http(pretrained_urls[self.key]) |
109 | | - print(f"init weight - {pretrained_urls[self.key]}") |
110 | | - if checkpoint is not None: |
111 | | - load_checkpoint_to_model(self, checkpoint) |
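With these changes the constructor both resolves the checkpoint and builds the timm model, so the separate init_weights step is removed. A minimal usage sketch, assuming the class edited here is exported as TimmBackbone from otx.algo.classification.backbones.timm (neither the class name nor the module path is visible in this diff):

    # Assumed import path and class name; the diff does not show them.
    from otx.algo.classification.backbones.timm import TimmBackbone

    # "efficientnetv2_s_21k" is listed in pretrained_urls, so pretrained=True takes the
    # load_from_http branch and caches the checkpoint under ~/.cache/torch/hub/checkpoints.
    backbone = TimmBackbone(backbone="efficientnetv2_s_21k", pretrained=True)

    # These attributes are set in the constructor shown above.
    print(backbone.num_head_features, backbone.num_features)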