-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathrun.py
More file actions
37 lines (30 loc) · 1.11 KB
/
run.py
File metadata and controls
37 lines (30 loc) · 1.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import argparse
from configs import get_cfg
from util.net import init_training
from util.util import run_pre, init_checkpoint
from trainer import get_trainer
import warnings
warnings.filterwarnings("ignore")
def main():
parser = argparse.ArgumentParser()
### unsupervised
parser.add_argument('-c', '--cfg_path', default='configs/benchmark/cdo/cdo_default.py')
parser.add_argument('-m', '--mode', default='train', choices=['train', 'test'])
parser.add_argument('--sleep', type=int, default=-1)
parser.add_argument('--memory', type=int, default=-1)
parser.add_argument('--dist_url', default='env://', type=str, help='url used to set up distributed training')
parser.add_argument('--logger_rank', default=0, type=int, help='GPU id to use.')
parser.add_argument('opts', help='path.key=value', default=None, nargs=argparse.REMAINDER,)
cfg_terminal = parser.parse_args()
cfg = get_cfg(cfg_terminal)
run_pre(cfg)
init_training(cfg)
init_checkpoint(cfg)
trainer = get_trainer(cfg)
trainer.run()
del trainer
if __name__ == '__main__':
import torch
main()