Skip to content

Commit afe86f7

Browse files
committed
refactor and change dataloader method
1 parent 68071cf commit afe86f7

File tree

8 files changed

+75
-64
lines changed

8 files changed

+75
-64
lines changed

configs/bisenetv1_city.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
weight_decay=5e-4,
77
warmup_iters=1000,
88
max_iter=80000,
9+
dataset='CityScapes',
910
im_root='./datasets/cityscapes',
1011
train_im_anns='./datasets/cityscapes/train.txt',
1112
val_im_anns='./datasets/cityscapes/val.txt',
1213
scales=[0.75, 2.],
1314
cropsize=[1024, 1024],
1415
ims_per_gpu=8,
16+
eval_ims_per_gpu=2,
1517
use_fp16=True,
1618
use_sync_bn=False,
1719
respth='./res',

configs/bisenetv2_city.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
weight_decay=5e-4,
88
warmup_iters = 1000,
99
max_iter = 150000,
10+
dataset='CityScapes',
1011
im_root='./datasets/cityscapes',
1112
train_im_anns='./datasets/cityscapes/train.txt',
1213
val_im_anns='./datasets/cityscapes/val.txt',
1314
scales=[0.25, 2.],
1415
cropsize=[512, 1024],
1516
ims_per_gpu=8,
17+
eval_ims_per_gpu=2,
1618
use_fp16=True,
1719
use_sync_bn=True,
1820
respth='./res',

dist_train.sh

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11

2-
export CUDA_VISIBLE_DEVICES=0,1
3-
PORT=52335
2+
export CUDA_VISIBLE_DEVICES=2,3
3+
PORT=52330
44
NGPUS=2
5+
cfg_file=configs/bisenetv1_city.py
56

6-
python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --config configs/bisenetv1_city.py --port $PORT
7+
python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file --port $PORT
8+
9+
# python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train.py --config $cfg_file --port $PORT

lib/cityscapes_cv2.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
import numpy as np
1313

1414
import lib.transform_cv2 as T
15-
from lib.sampler import RepeatedDistSampler
16-
from lib.base_dataset import BaseDataset, TransformationTrain, TransformationVal
15+
from lib.base_dataset import BaseDataset
1716

1817

1918
labels_info = [
@@ -74,48 +73,6 @@ def __init__(self, dataroot, annpath, trans_func=None, mode='train'):
7473
)
7574

7675

77-
def get_data_loader(datapth, annpath, ims_per_gpu, scales, cropsize, max_iter=None, mode='train', distributed=True):
78-
if mode == 'train':
79-
trans_func = TransformationTrain(scales, cropsize)
80-
batchsize = ims_per_gpu
81-
shuffle = True
82-
drop_last = True
83-
elif mode == 'val':
84-
trans_func = TransformationVal()
85-
batchsize = ims_per_gpu
86-
shuffle = False
87-
drop_last = False
88-
89-
ds = CityScapes(datapth, annpath, trans_func=trans_func, mode=mode)
90-
91-
if distributed:
92-
assert dist.is_available(), "dist should be initialzed"
93-
if mode == 'train':
94-
assert not max_iter is None
95-
n_train_imgs = ims_per_gpu * dist.get_world_size() * max_iter
96-
sampler = RepeatedDistSampler(ds, n_train_imgs, shuffle=shuffle)
97-
else:
98-
sampler = torch.utils.data.distributed.DistributedSampler(
99-
ds, shuffle=shuffle)
100-
batchsampler = torch.utils.data.sampler.BatchSampler(
101-
sampler, batchsize, drop_last=drop_last
102-
)
103-
dl = DataLoader(
104-
ds,
105-
batch_sampler=batchsampler,
106-
num_workers=4,
107-
pin_memory=True,
108-
)
109-
else:
110-
dl = DataLoader(
111-
ds,
112-
batch_size=batchsize,
113-
shuffle=shuffle,
114-
drop_last=drop_last,
115-
num_workers=4,
116-
pin_memory=True,
117-
)
118-
return dl
11976

12077

12178

lib/get_dataloader.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
2+
import torch
3+
from torch.utils.data import Dataset, DataLoader
4+
import torch.distributed as dist
5+
6+
from lib.sampler import RepeatedDistSampler
7+
from lib.base_dataset import TransformationTrain, TransformationVal
8+
from lib.cityscapes_cv2 import CityScapes
9+
10+
11+
def get_data_loader(cfg, mode='train', distributed=True):
12+
if mode == 'train':
13+
trans_func = TransformationTrain(cfg.scales, cfg.cropsize)
14+
batchsize = cfg.ims_per_gpu
15+
annpath = cfg.train_im_anns
16+
shuffle = True
17+
drop_last = True
18+
elif mode == 'val':
19+
trans_func = TransformationVal()
20+
batchsize = cfg.eval_ims_per_gpu
21+
annpath = cfg.val_im_anns
22+
shuffle = False
23+
drop_last = False
24+
25+
ds = eval(cfg.dataset)(cfg.im_root, annpath, trans_func=trans_func, mode=mode)
26+
27+
if distributed:
28+
assert dist.is_available(), "dist should be initialzed"
29+
if mode == 'train':
30+
assert not cfg.max_iter is None
31+
n_train_imgs = cfg.ims_per_gpu * dist.get_world_size() * cfg.max_iter
32+
sampler = RepeatedDistSampler(ds, n_train_imgs, shuffle=shuffle)
33+
else:
34+
sampler = torch.utils.data.distributed.DistributedSampler(
35+
ds, shuffle=shuffle)
36+
batchsampler = torch.utils.data.sampler.BatchSampler(
37+
sampler, batchsize, drop_last=drop_last
38+
)
39+
dl = DataLoader(
40+
ds,
41+
batch_sampler=batchsampler,
42+
num_workers=4,
43+
pin_memory=True,
44+
)
45+
else:
46+
dl = DataLoader(
47+
ds,
48+
batch_size=batchsize,
49+
shuffle=shuffle,
50+
drop_last=drop_last,
51+
num_workers=4,
52+
pin_memory=True,
53+
)
54+
return dl

tools/evaluate.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from lib.models import model_factory
2323
from configs import set_cfg_from_file
2424
from lib.logger import setup_logger
25-
from lib.cityscapes_cv2 import get_data_loader
25+
from lib.get_dataloader import get_data_loader
2626

2727

2828

@@ -184,10 +184,9 @@ def __call__(self, net, dl, n_classes):
184184

185185

186186
@torch.no_grad()
187-
def eval_model(net, ims_per_gpu, im_root, im_anns):
187+
def eval_model(cfg, net):
188188
is_dist = dist.is_initialized()
189-
dl = get_data_loader(im_root, im_anns, ims_per_gpu, None,
190-
None, mode='val', distributed=is_dist)
189+
dl = get_data_loader(cfg, mode='val', distributed=is_dist)
191190
net.eval()
192191

193192
heads, mious = [], []
@@ -251,7 +250,7 @@ def evaluate(cfg, weight_pth):
251250
)
252251

253252
## evaluator
254-
heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns)
253+
heads, mious = eval_model(cfg, net)
255254
logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))
256255

257256

tools/train.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from lib.models import model_factory
2121
from configs import set_cfg_from_file
22-
from lib.cityscapes_cv2 import get_data_loader
22+
from lib.get_dataloader import get_data_loader
2323
from tools.evaluate import eval_model
2424
from lib.ohem_ce_loss import OhemCELoss
2525
from lib.lr_scheduler import WarmupPolyLrScheduler
@@ -133,10 +133,7 @@ def train():
133133
is_dist = dist.is_initialized()
134134

135135
## dataset
136-
dl = get_data_loader(
137-
cfg.im_root, cfg.train_im_anns,
138-
cfg.ims_per_gpu, cfg.scales, cfg.cropsize,
139-
cfg.max_iter, mode='train', distributed=is_dist)
136+
dl = get_data_loader(cfg, mode='train', distributed=is_dist)
140137

141138
## model
142139
net, criteria_pre, criteria_aux = set_model()
@@ -202,7 +199,7 @@ def train():
202199

203200
logger.info('\nevaluating the final model')
204201
torch.cuda.empty_cache()
205-
heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns)
202+
heads, mious = eval_model(cfg, net)
206203
logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))
207204

208205
return

tools/train_amp.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from lib.models import model_factory
2222
from configs import set_cfg_from_file
23-
from lib.cityscapes_cv2 import get_data_loader
23+
from lib.get_dataloader import get_data_loader
2424
from evaluate import eval_model
2525
from lib.ohem_ce_loss import OhemCELoss
2626
from lib.lr_scheduler import WarmupPolyLrScheduler
@@ -122,10 +122,7 @@ def train():
122122
is_dist = dist.is_initialized()
123123

124124
## dataset
125-
dl = get_data_loader(
126-
cfg.im_root, cfg.train_im_anns,
127-
cfg.ims_per_gpu, cfg.scales, cfg.cropsize,
128-
cfg.max_iter, mode='train', distributed=is_dist)
125+
dl = get_data_loader(cfg, mode='train', distributed=is_dist)
129126

130127
## model
131128
net, criteria_pre, criteria_aux = set_model()
@@ -187,7 +184,7 @@ def train():
187184

188185
logger.info('\nevaluating the final model')
189186
torch.cuda.empty_cache()
190-
heads, mious = eval_model(net, 2, cfg.im_root, cfg.val_im_anns)
187+
heads, mious = eval_model(cfg, net)
191188
logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))
192189

193190
return

0 commit comments

Comments
 (0)