Skip to content

Commit 9265c59

Browse files
authored
Fix cached dir for timm & hugging-face (#3914)
* Fix cached dir * Pretrained weight download unit-test * Fix pre-commit
1 parent 52221e3 commit 9265c59

File tree

8 files changed

+42
-42
lines changed

8 files changed

+42
-42
lines changed

src/otx/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,19 @@
55

66
__version__ = "2.2.0rc0"
77

8+
import os
9+
from pathlib import Path
10+
811
from otx.core.types import * # noqa: F403
912

13+
# Set the value of HF_HUB_CACHE to set the cache folder that stores the pretrained weights for timm and huggingface.
14+
# Refer: huggingface_hub/constants.py::HF_HUB_CACHE
15+
# Default, Pretrained weight is saved into ~/.cache/torch/hub/checkpoints
16+
os.environ["HF_HUB_CACHE"] = os.getenv(
17+
"HF_HUB_CACHE",
18+
str(Path.home() / ".cache" / "torch" / "hub" / "checkpoints"),
19+
)
20+
1021
OTX_LOGO: str = """
1122
1223
██████╗ ████████╗ ██╗ ██╗

src/otx/algo/classification/backbones/timm.py

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,15 @@
1515
import torch
1616
from torch import nn
1717

18-
from otx.algo.utils.mmengine_utils import load_from_http
19-
20-
PRETRAINED_ROOT = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/"
21-
pretrained_urls = {
22-
"efficientnetv2_s_21k": PRETRAINED_ROOT + "tf_efficientnetv2_s_21k-6337ad01.pth",
23-
"efficientnetv2_s_1k": PRETRAINED_ROOT + "tf_efficientnetv2_s_21ft1k-d7dafa41.pth",
24-
}
25-
26-
TIMM_MODEL_NAME_DICT = {
27-
"mobilenetv3_large_21k": "mobilenetv3_large_100_miil_in21k",
28-
"mobilenetv3_large_1k": "mobilenetv3_large_100_miil",
29-
"tresnet": "tresnet_m",
30-
"efficientnetv2_s_21k": "tf_efficientnetv2_s.in21k",
31-
"efficientnetv2_s_1k": "tf_efficientnetv2_s_in21ft1k",
32-
"efficientnetv2_m_21k": "tf_efficientnetv2_m_in21k",
33-
"efficientnetv2_m_1k": "tf_efficientnetv2_m_in21ft1k",
34-
"efficientnetv2_b0": "tf_efficientnetv2_b0",
35-
}
36-
3718
TimmModelType = Literal[
38-
"mobilenetv3_large_21k",
39-
"mobilenetv3_large_1k",
40-
"tresnet",
41-
"efficientnetv2_s_21k",
42-
"efficientnetv2_s_1k",
43-
"efficientnetv2_m_21k",
44-
"efficientnetv2_m_1k",
45-
"efficientnetv2_b0",
19+
"mobilenetv3_large_100_miil_in21k",
20+
"mobilenetv3_large_100_miil",
21+
"tresnet_m",
22+
"tf_efficientnetv2_s.in21k",
23+
"tf_efficientnetv2_s.in21ft1k",
24+
"tf_efficientnetv2_m.in21k",
25+
"tf_efficientnetv2_m.in21ft1k",
26+
"tf_efficientnetv2_b0",
4627
]
4728

4829

@@ -60,14 +41,10 @@ def __init__(
6041
self.backbone = backbone
6142
self.pretrained: bool | dict = pretrained
6243
self.is_mobilenet = backbone.startswith("mobilenet")
63-
if pretrained and self.backbone in pretrained_urls:
64-
# This pretrained weight is saved into ~/.cache/torch/hub/checkpoints
65-
# Otherwise, it is stored in ~/.cache/huggingface/hub. (timm defaults)
66-
self.pretrained = load_from_http(filename=pretrained_urls[self.backbone])
6744

6845
self.model = timm.create_model(
69-
TIMM_MODEL_NAME_DICT[self.backbone],
70-
pretrained=self.pretrained,
46+
self.backbone,
47+
pretrained=pretrained,
7148
num_classes=1000,
7249
)
7350

src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
model:
22
class_path: otx.algo.classification.timm_model.TimmModelForHLabelCls
33
init_args:
4-
backbone: efficientnetv2_s_21k
4+
backbone: tf_efficientnetv2_s.in21k
55

66
optimizer:
77
class_path: torch.optim.SGD

src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ model:
22
class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls
33
init_args:
44
label_info: 1000
5-
backbone: efficientnetv2_s_21k
5+
backbone: tf_efficientnetv2_s.in21k
66

77
optimizer:
88
class_path: torch.optim.SGD

src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ model:
22
class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls
33
init_args:
44
label_info: 1000
5-
backbone: efficientnetv2_s_21k
5+
backbone: tf_efficientnetv2_s.in21k
66
train_type: SEMI_SUPERVISED
77

88
optimizer:

src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ model:
22
class_path: otx.algo.classification.timm_model.TimmModelForMultilabelCls
33
init_args:
44
label_info: 1000
5-
backbone: efficientnetv2_s_21k
5+
backbone: tf_efficientnetv2_s.in21k
66

77
optimizer:
88
class_path: torch.optim.SGD
Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,28 @@
11
# Copyright (C) 2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import os
5+
import shutil
6+
from pathlib import Path
7+
48
import torch
59
from otx.algo.classification.backbones.timm import TimmBackbone
610

711

812
class TestOTXEfficientNetV2:
913
def test_forward(self):
10-
model = TimmBackbone(backbone="efficientnetv2_s_21k")
14+
model = TimmBackbone(backbone="tf_efficientnetv2_s.in21k")
1115
assert model(torch.randn(1, 3, 244, 244))[0].shape == torch.Size([1, 1280, 8, 8])
1216

1317
def test_get_config_optim(self):
14-
model = TimmBackbone(backbone="efficientnetv2_s_21k")
18+
model = TimmBackbone(backbone="tf_efficientnetv2_s.in21k")
1519
assert model.get_config_optim([0.01])[0]["lr"] == 0.01
1620
assert model.get_config_optim(0.01)[0]["lr"] == 0.01
21+
22+
def test_check_pretrained_weight_download(self):
23+
target = Path(os.environ.get("HF_HUB_CACHE")) / "models--timm--tf_efficientnetv2_s.in21k"
24+
if target.exists():
25+
shutil.rmtree(target)
26+
assert not target.exists()
27+
TimmBackbone(backbone="tf_efficientnetv2_s.in21k", pretrained=True)
28+
assert target.exists()

tests/unit/algo/classification/test_timm_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
def fxt_multi_class_cls_model():
2222
return TimmModelForMulticlassCls(
2323
label_info=10,
24-
backbone="efficientnetv2_s_21k",
24+
backbone="tf_efficientnetv2_s.in21k",
2525
)
2626

2727

@@ -59,7 +59,7 @@ def test_predict_step(self, fxt_multi_class_cls_model, fxt_multiclass_cls_batch_
5959
def fxt_multi_label_cls_model():
6060
return TimmModelForMultilabelCls(
6161
label_info=10,
62-
backbone="efficientnetv2_s_21k",
62+
backbone="tf_efficientnetv2_s.in21k",
6363
)
6464

6565

@@ -97,7 +97,7 @@ def test_predict_step(self, fxt_multi_label_cls_model, fxt_multilabel_cls_batch_
9797
def fxt_h_label_cls_model(fxt_hlabel_cifar):
9898
return TimmModelForHLabelCls(
9999
label_info=fxt_hlabel_cifar,
100-
backbone="efficientnetv2_s_21k",
100+
backbone="tf_efficientnetv2_s.in21k",
101101
)
102102

103103

0 commit comments

Comments
 (0)