Skip to content

Commit 9bb710f

Browse files
committed
use pretrain backbone
1 parent 9ff609a commit 9bb710f

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

configs/bisenetv2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
cfg = dict(
44
model_type='bisenetv2',
55
num_aux_heads=4,
6-
lr_start = 1 * 5e-3,
6+
lr_start = 5e-3,
77
weight_decay=5e-4,
88
warmup_iters = 1000,
99
max_iter = 150000,

dist_train.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
export CUDA_VISIBLE_DEVICES=6,7
3-
PORT=52330
3+
PORT=52332
44
NGPUS=2
55

66
python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --model bisenetv2 --port $PORT

lib/models/bisenetv2.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import torch
33
import torch.nn as nn
44
import torch.nn.functional as F
5+
import torch.utils.model_zoo as modelzoo
6+
7+
backbone_url = 'https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/backbone_v2.pth'
58

69

710
class ConvBNReLU(nn.Module):
@@ -325,7 +328,6 @@ def __init__(self, n_classes, output_aux=True):
325328
self.aux5_4 = SegmentHead(128, 128, n_classes, up_factor=32)
326329

327330
self.init_weights()
328-
self.load_pretrain()
329331

330332
def forward(self, x):
331333
size = x.size()[2:]
@@ -354,13 +356,13 @@ def init_weights(self):
354356
else:
355357
nn.init.ones_(module.weight)
356358
nn.init.zeros_(module.bias)
359+
self.load_pretrain()
357360

358361
def load_pretrain(self):
359-
state = torch.load('pretrained/bisenetv2_pretrain.pth', map_location='cpu')
360-
state = {k:v for k,v in state.items() if not k in ('fc', 'head', 'dense_head')}
362+
state = modelzoo.load_url(backbone_url)
361363
for name, child in self.named_children():
362364
if name in state.keys():
363-
child.load_state_dict(state[name])
365+
child.load_state_dict(state[name], strict=True)
364366

365367

366368
def get_params(self):

0 commit comments

Comments
 (0)