Skip to content

Commit 4c8555e

Browse files
authored
Fix pretrained weight cached dir for timm (#3909)
* Fix pretrained_weight for timm * Fix unit-test
1 parent 8115b52 commit 4c8555e

File tree

2 files changed

+13
-20
lines changed
  • src/otx/algo/classification/backbones
  • tests/unit/algo/classification/backbones

2 files changed

+13
-20
lines changed

src/otx/algo/classification/backbones/timm.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,13 @@
99
"""
1010
from __future__ import annotations
1111

12-
from pathlib import Path
1312
from typing import Literal
1413

1514
import timm
1615
import torch
1716
from torch import nn
1817

19-
from otx.algo.utils.mmengine_utils import load_checkpoint_to_model, load_from_http
18+
from otx.algo.utils.mmengine_utils import load_from_http
2019

2120
PRETRAINED_ROOT = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/"
2221
pretrained_urls = {
@@ -28,7 +27,7 @@
2827
"mobilenetv3_large_21k": "mobilenetv3_large_100_miil_in21k",
2928
"mobilenetv3_large_1k": "mobilenetv3_large_100_miil",
3029
"tresnet": "tresnet_m",
31-
"efficientnetv2_s_21k": "tf_efficientnetv2_s_in21k",
30+
"efficientnetv2_s_21k": "tf_efficientnetv2_s.in21k",
3231
"efficientnetv2_s_1k": "tf_efficientnetv2_s_in21ft1k",
3332
"efficientnetv2_m_21k": "tf_efficientnetv2_m_in21k",
3433
"efficientnetv2_m_1k": "tf_efficientnetv2_m_in21ft1k",
@@ -59,12 +58,19 @@ def __init__(
5958
):
6059
super().__init__(**kwargs)
6160
self.backbone = backbone
62-
self.pretrained = pretrained
61+
self.pretrained: bool | dict = pretrained
6362
self.is_mobilenet = backbone.startswith("mobilenet")
63+
if pretrained and self.backbone in pretrained_urls:
64+
# This pretrained weight is saved into ~/.cache/torch/hub/checkpoints
65+
# Otherwise, it is stored in ~/.cache/huggingface/hub. (timm defaults)
66+
self.pretrained = load_from_http(filename=pretrained_urls[self.backbone])
67+
68+
self.model = timm.create_model(
69+
TIMM_MODEL_NAME_DICT[self.backbone],
70+
pretrained=self.pretrained,
71+
num_classes=1000,
72+
)
6473

65-
self.model = timm.create_model(TIMM_MODEL_NAME_DICT[self.backbone], pretrained=pretrained, num_classes=1000)
66-
if self.pretrained:
67-
print(f"init weight - {pretrained_urls[self.backbone]}")
6874
self.model.classifier = None # Detach classifier. Only use 'backbone' part in otx.
6975
self.num_head_features = self.model.num_features
7076
self.num_features = self.model.conv_head.in_channels if self.is_mobilenet else self.model.num_features
@@ -97,15 +103,3 @@ def get_config_optim(self, lrs: list[float] | float) -> list[dict[str, float]]:
97103
param_dict["lr"] = lrs
98104

99105
return parameters
100-
101-
def init_weights(self, pretrained: str | bool | None = None) -> None:
102-
"""Initialize weights."""
103-
checkpoint = None
104-
if isinstance(pretrained, str) and Path(pretrained).exists():
105-
checkpoint = torch.load(pretrained, None)
106-
print(f"init weight - {pretrained}")
107-
elif pretrained is not None:
108-
checkpoint = load_from_http(pretrained_urls[self.key])
109-
print(f"init weight - {pretrained_urls[self.key]}")
110-
if checkpoint is not None:
111-
load_checkpoint_to_model(self, checkpoint)

tests/unit/algo/classification/backbones/test_timm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
class TestOTXEfficientNetV2:
99
def test_forward(self):
1010
model = TimmBackbone(backbone="efficientnetv2_s_21k")
11-
model.init_weights()
1211
assert model(torch.randn(1, 3, 244, 244))[0].shape == torch.Size([1, 1280, 8, 8])
1312

1413
def test_get_config_optim(self):

0 commit comments

Comments
 (0)