Skip to content

Commit 5b6c20b

Browse files
fix(backbone): update model loading to use CPU mapping
1 parent: bc96afb · commit: 5b6c20b

File tree

7 files changed

+46 -13 lines changed (46 additions, 13 deletions)

7 files changed

+46
-13
lines changed

focoos/nn/backbone/convnextv2.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -158,7 +158,7 @@ def __init__(self, config: ConvNeXtV2Config):
158158
cur += depths[i]
159159

160160
if config.use_pretrained and backbone_url:
161-
state = torch.hub.load_state_dict_from_url(backbone_url)
161+
state = torch.hub.load_state_dict_from_url(backbone_url, map_location="cpu")
162162
self.load_state_dict(state)
163163
logger.info(f"Load ConvNeXtV2{config.model_size} state_dict")
164164

focoos/nn/backbone/csp_darknet.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -369,11 +369,11 @@ def __init__(
369369
}
370370
if config.use_pretrained:
371371
if config.backbone_url:
372-
state = torch.hub.load_state_dict_from_url(config.backbone_url)
372+
state = torch.hub.load_state_dict_from_url(config.backbone_url, map_location="cpu")
373373
self.load_state_dict(state)
374374
logger.info(f"Loaded pretrained weights from {config.backbone_url}")
375375
else:
376-
state = torch.hub.load_state_dict_from_url(CONFIGS[config.size]["url"])
376+
state = torch.hub.load_state_dict_from_url(CONFIGS[config.size]["url"], map_location="cpu")
377377
self.load_state_dict(state)
378378
logger.info(f"Loaded pretrained weights from {CONFIGS[config.size]['url']}")
379379

focoos/nn/backbone/mobilenet_v2.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -205,7 +205,7 @@ def __init__(
205205
self.layers.append(layer_name)
206206

207207
if config.use_pretrained and config.backbone_url:
208-
state = torch.hub.load_state_dict_from_url(config.backbone_url)
208+
state = torch.hub.load_state_dict_from_url(config.backbone_url, map_location="cpu")
209209
self.load_state_dict(state)
210210
logger.info("Load MobileNetV2 state_dict")
211211

focoos/nn/backbone/resnet.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -227,7 +227,7 @@ def __init__(
227227
self._freeze_norm(self)
228228

229229
if use_pretrained:
230-
state = torch.hub.load_state_dict_from_url(backbone_url)
230+
state = torch.hub.load_state_dict_from_url(backbone_url, map_location="cpu")
231231
self.load_state_dict(state)
232232
logger.info(f"Load ResNet{depth} state_dict")
233233

focoos/nn/backbone/stdc.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -248,7 +248,7 @@ def __init__(self, config: STDCConfig):
248248
self.out_ids = 1, 5, 10, 13
249249

250250
if config.use_pretrained and config.backbone_url:
251-
state = torch.hub.load_state_dict_from_url(config.backbone_url)
251+
state = torch.hub.load_state_dict_from_url(config.backbone_url, map_location="cpu")
252252
self.load_state_dict(state)
253253
logger.info("Load STDC state_dict")
254254

focoos/nn/backbone/swin.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -679,7 +679,7 @@ def __init__(
679679
self.add_module(layer_name, layer)
680680

681681
if config.use_pretrained and backbone_url:
682-
state = torch.hub.load_state_dict_from_url(backbone_url)
682+
state = torch.hub.load_state_dict_from_url(backbone_url, map_location="cpu")
683683
self.load_state_dict(state, strict=False)
684684
logger.info(f"Loaded pretrained weights from {backbone_url}")
685685

tests/test_backbone.py

Lines changed: 39 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -11,16 +11,49 @@
1111
{"model_type": "stdc", "use_pretrained": True, "size": size} for size in ["nano", "small", "large"]
1212
]
1313
stdc_configs_base = [
14-
{"model_type": "stdc", "use_pretrained": True, "base": 64, "layers": [2, 2, 2], "block_num": 4, "block_type": "cat"}
14+
{
15+
"model_type": "stdc",
16+
"use_pretrained": True,
17+
"base": 64,
18+
"layers": [2, 2, 2],
19+
"block_num": 4,
20+
"block_type": "cat",
21+
},
22+
{
23+
"model_type": "stdc",
24+
"use_pretrained": False,
25+
"base": 64,
26+
"layers": [4, 5, 3],
27+
"block_num": 4,
28+
"block_type": "cat",
29+
},
1530
]
1631
BACKBONE_CONFIGS = {
17-
"resnet": [{"model_type": "resnet", "use_pretrained": False, "depth": 18}],
32+
"resnet": [
33+
{"model_type": "resnet", "use_pretrained": True, "depth": 18},
34+
{"model_type": "resnet", "use_pretrained": True, "depth": 34},
35+
{"model_type": "resnet", "use_pretrained": False, "depth": 50},
36+
{"model_type": "resnet", "use_pretrained": False, "depth": 101},
37+
],
1838
"stdc": stdc_configs_size + stdc_configs_base,
19-
"swin": [{"model_type": "swin", "use_pretrained": False}],
20-
"mobilenet_v2": [{"model_type": "mobilenet_v2", "use_pretrained": False}],
21-
"convnextv2": [{"model_type": "convnextv2", "use_pretrained": False}],
39+
"swin": [
40+
{"model_type": "swin", "use_pretrained": False},
41+
{
42+
"model_type": "swin",
43+
"use_pretrained": True,
44+
},
45+
],
46+
"mobilenet_v2": [
47+
{"model_type": "mobilenet_v2", "use_pretrained": False},
48+
{"model_type": "mobilenet_v2", "use_pretrained": True},
49+
],
50+
"convnextv2": [
51+
{"model_type": "convnextv2", "use_pretrained": False},
52+
{"model_type": "convnextv2", "use_pretrained": True},
53+
],
2254
"csp_darknet": [
23-
{"model_type": "csp_darknet", "use_pretrained": False, "size": size} for size in ["small", "medium", "large"]
55+
{"model_type": "csp_darknet", "use_pretrained": True if size == "small" else False, "size": size}
56+
for size in ["small", "medium", "large"]
2457
],
2558
}
2659

0 commit comments

Comments
 (0)