Skip to content

Commit 68071cf

Browse files
committed
use outer config.py
1 parent 9bb710f commit 68071cf

File tree

10 files changed

+38
-31
lines changed

10 files changed

+38
-31
lines changed

README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ My platform is like this:
2929
## get start
3030
With a pretrained weight, you can run inference on an single image like this:
3131
```
32-
$ python tools/demo.py --model bisenetv2 --weight-path /path/to/your/weights.pth --img-path ./example.png
32+
$ python tools/demo.py --config configs/bisenetv2_city.py --weight-path /path/to/your/weights.pth --img-path ./example.png
3333
```
3434
This would run inference on the image and save the result image to `./res.jpg`.
3535

@@ -64,10 +64,10 @@ In order to train the model, you can run command like this:
6464
$ export CUDA_VISIBLE_DEVICES=0,1
6565
6666
# if you want to train with apex
67-
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train.py --model bisenetv2 # or bisenetv1
67+
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train.py --config configs/bisenetv2_city.py # or bisenetv1
6868
6969
# if you want to train with pytorch fp16 feature from torch 1.6
70-
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train_amp.py --model bisenetv2 # or bisenetv1
70+
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train_amp.py --config configs/bisenetv2_city.py # or bisenetv1
7171
```
7272

7373
Note that though `bisenetv2` has fewer flops, it requires much more training iterations. The the training time of `bisenetv1` is shorter.
@@ -77,17 +77,17 @@ Note that though `bisenetv2` has fewer flops, it requires much more training ite
7777
You can also load the trained model weights and finetune from it:
7878
```
7979
$ export CUDA_VISIBLE_DEVICES=0,1
80-
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train.py --finetune-from ./res/model_final.pth --model bisenetv2 # or bisenetv1
80+
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train.py --finetune-from ./res/model_final.pth --config ./configs/bisenetv2_city.py # or bisenetv1
8181
8282
# same with pytorch fp16 feature
83-
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train_amp.py --finetune-from ./res/model_final.pth --model bisenetv2 # or bisenetv1
83+
$ python -m torch.distributed.launch --nproc_per_node=2 tools/train_amp.py --finetune-from ./res/model_final.pth --config ./configs/bisenetv2_city.py # or bisenetv1
8484
```
8585

8686

8787
## eval pretrained models
8888
You can also evaluate a trained model like this:
8989
```
90-
$ python tools/evaluate.py --model bisenetv1 --weight-path /path/to/your/weight.pth
90+
$ python tools/evaluate.py --config configs/bisenetv1_city.py --weight-path /path/to/your/weight.pth
9191
```
9292

9393
## Infer with tensorrt

configs/__init__.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11

2-
from .bisenetv1 import cfg as bisenetv1_cfg
3-
from .bisenetv2 import cfg as bisenetv2_cfg
42

3+
import importlib
54

65

76
class cfg_dict(object):
@@ -10,7 +9,12 @@ def __init__(self, d):
109
self.__dict__ = d
1110

1211

13-
cfg_factory = dict(
14-
bisenetv1=cfg_dict(bisenetv1_cfg),
15-
bisenetv2=cfg_dict(bisenetv2_cfg),
16-
)
12+
def set_cfg_from_file(cfg_path):
13+
spec = importlib.util.spec_from_file_location('cfg_file', cfg_path)
14+
cfg_file = importlib.util.module_from_spec(spec)
15+
spec_loader = spec.loader.exec_module(cfg_file)
16+
cfg = cfg_file.cfg
17+
return cfg_dict(cfg)
18+
19+
20+
File renamed without changes.
File renamed without changes.

dist_train.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
export CUDA_VISIBLE_DEVICES=6,7
3-
PORT=52332
2+
export CUDA_VISIBLE_DEVICES=0,1
3+
PORT=52335
44
NGPUS=2
55

6-
python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --model bisenetv2 --port $PORT
6+
python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --config configs/bisenetv1_city.py --port $PORT

tools/demo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,19 @@
1010

1111
import lib.transform_cv2 as T
1212
from lib.models import model_factory
13-
from configs import cfg_factory
13+
from configs import set_cfg_from_file
1414

1515
torch.set_grad_enabled(False)
1616
np.random.seed(123)
1717

1818

1919
# args
2020
parse = argparse.ArgumentParser()
21-
parse.add_argument('--model', dest='model', type=str, default='bisenetv2',)
21+
parse.add_argument('--config', dest='config', type=str, default='configs/bisenetv2.py',)
2222
parse.add_argument('--weight-path', type=str, default='./res/model_final.pth',)
2323
parse.add_argument('--img-path', dest='img_path', type=str, default='./example.png',)
2424
args = parse.parse_args()
25-
cfg = cfg_factory[args.model]
25+
cfg = set_cfg_from_file(args.config)
2626

2727

2828
palette = np.random.randint(0, 256, (256, 3), dtype=np.uint8)

tools/evaluate.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch.distributed as dist
2121

2222
from lib.models import model_factory
23-
from configs import cfg_factory
23+
from configs import set_cfg_from_file
2424
from lib.logger import setup_logger
2525
from lib.cityscapes_cv2 import get_data_loader
2626

@@ -262,13 +262,14 @@ def parse_args():
262262
parse.add_argument('--weight-path', dest='weight_pth', type=str,
263263
default='model_final.pth',)
264264
parse.add_argument('--port', dest='port', type=int, default=44553,)
265-
parse.add_argument('--model', dest='model', type=str, default='bisenetv2',)
265+
parse.add_argument('--config', dest='config', type=str,
266+
default='configs/bisenetv2.py',)
266267
return parse.parse_args()
267268

268269

269270
def main():
270271
args = parse_args()
271-
cfg = cfg_factory[args.model]
272+
cfg = set_cfg_from_file(args.config)
272273
if not args.local_rank == -1:
273274
torch.cuda.set_device(args.local_rank)
274275
dist.init_process_group(backend='nccl',

tools/export_onnx.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,22 @@
66
import torch
77

88
from lib.models import model_factory
9-
from configs import cfg_factory
9+
from configs import set_cfg_from_file
1010

1111
torch.set_grad_enabled(False)
1212

1313

1414
parse = argparse.ArgumentParser()
15-
parse.add_argument('--model', dest='model', type=str, default='bisenetv1',)
15+
parse.add_argument('--config', dest='config', type=str,
16+
default='configs/bisenetv2.py',)
1617
parse.add_argument('--weight-path', dest='weight_pth', type=str,
1718
default='model_final.pth')
1819
parse.add_argument('--outpath', dest='out_pth', type=str,
1920
default='model.onnx')
2021
args = parse.parse_args()
2122

2223

23-
cfg = cfg_factory[args.model]
24+
cfg = set_cfg_from_file(args.config)
2425
if cfg.use_sync_bn: cfg.use_sync_bn = False
2526

2627
net = model_factory[cfg.model_type](19, output_aux=False)

tools/train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from torch.utils.data import DataLoader
1919

2020
from lib.models import model_factory
21-
from configs import cfg_factory
21+
from configs import set_cfg_from_file
2222
from lib.cityscapes_cv2 import get_data_loader
2323
from tools.evaluate import eval_model
2424
from lib.ohem_ce_loss import OhemCELoss
@@ -50,13 +50,13 @@ def parse_args():
5050
parse = argparse.ArgumentParser()
5151
parse.add_argument('--local_rank', dest='local_rank', type=int, default=-1,)
5252
parse.add_argument('--port', dest='port', type=int, default=44554,)
53-
parse.add_argument('--model', dest='model', type=str, default='bisenetv2',)
53+
parse.add_argument('--config', dest='config', type=str,
54+
default='configs/bisenetv2.py',)
5455
parse.add_argument('--finetune-from', type=str, default=None,)
5556
return parse.parse_args()
5657

5758
args = parse_args()
58-
cfg = cfg_factory[args.model]
59-
59+
cfg = set_cfg_from_file(args.config)
6060

6161

6262
def set_model():

tools/train_amp.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import torch.cuda.amp as amp
2020

2121
from lib.models import model_factory
22-
from configs import cfg_factory
22+
from configs import set_cfg_from_file
2323
from lib.cityscapes_cv2 import get_data_loader
2424
from evaluate import eval_model
2525
from lib.ohem_ce_loss import OhemCELoss
@@ -44,12 +44,13 @@ def parse_args():
4444
parse = argparse.ArgumentParser()
4545
parse.add_argument('--local_rank', dest='local_rank', type=int, default=-1,)
4646
parse.add_argument('--port', dest='port', type=int, default=44554,)
47-
parse.add_argument('--model', dest='model', type=str, default='bisenetv2',)
47+
parse.add_argument('--config', dest='config', type=str,
48+
default='configs/bisenetv2.py',)
4849
parse.add_argument('--finetune-from', type=str, default=None,)
4950
return parse.parse_args()
5051

5152
args = parse_args()
52-
cfg = cfg_factory[args.model]
53+
cfg = set_cfg_from_file(args.config)
5354

5455

5556

0 commit comments

Comments
 (0)