-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path train.py
More file actions
55 lines (47 loc) · 1.67 KB
/
train.py
File metadata and controls
55 lines (47 loc) · 1.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import argparse
from base.base_config import BaseConfig
from base.base_trainer import build_trainer
# NOTE(review): the trainer classes imported below are not referenced anywhere
# in this file; presumably the import itself is what makes them available to
# build_trainer (e.g. via a registry) — verify before removing them as "unused".
from trainers.toy_trainer import ToyTrainer
from trainers.diffposetalk_trainer import StyleEncoderTrainer, DiffPoseTalkTrainer
import warnings
# Silence every Python warning for the whole process; this runs at import time.
warnings.filterwarnings('ignore')
def main(args):
    """Build the trainer described by the config file and run the requested mode.

    Args:
        args: Parsed command-line namespace. Reads ``config_file``, ``opts``
            and ``mode``. When ``mode == "eval"`` it also reads ``model_dir``
            (required) and ``load_epoch`` (optional).

    Raises:
        ValueError: If ``args.mode`` is not one of ``"train"``, ``"eval"`` or
            ``"analysis"``, or if eval mode is requested without a
            ``model_dir``.
    """
    # Validate the cheap things first so we fail before building the config
    # and trainer, which may be expensive.
    if args.mode not in ("train", "eval", "analysis"):
        raise ValueError(f"Unknown mode: {args.mode}")
    # getattr: older callers may pass a namespace without these attributes.
    model_dir = getattr(args, "model_dir", None)
    if args.mode == "eval" and model_dir is None:
        raise ValueError("eval mode requires args.model_dir (pass --model-dir)")

    base_cfg = BaseConfig()
    base_cfg.cfg.merge_from_file(args.config_file)
    # Apply optional command-line overrides on top of the file config. Skip
    # when there are none: opts may be None when main() is called directly.
    if args.opts:
        base_cfg.cfg.merge_from_list(args.opts)
    # Freeze the config before handing it to the trainer.
    base_cfg.cfg.freeze()

    trainer = build_trainer(base_cfg.cfg)
    if args.mode == "eval":
        # NOTE(review): load_epoch may be None; confirm load_model handles it.
        trainer.load_model(model_dir, epoch=getattr(args, "load_epoch", None))
        trainer.test()
    elif args.mode == "analysis":
        trainer.dm.data_analysis()
    else:  # "train" — the only remaining valid mode after validation above
        trainer.train()
def build_parser():
    """Create the command-line parser for this training/evaluation script.

    Returns:
        argparse.ArgumentParser: parser whose namespace is consumed by
        ``main`` (``config_file``, ``mode``, ``model_dir``, ``load_epoch``,
        ``opts``) and by the ``__main__`` guard (``debug``, ``debug_port``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config-file', type=str, default='config/codetalker/vocaset/stage1.yaml',
        help='path to config file',
    )
    parser.add_argument(
        '--mode', type=str, choices=['train', 'eval', 'analysis'],
        default='train', help='Operation mode: train, eval, or analysis',
    )
    # main() reads args.model_dir and args.load_epoch in eval mode, so the
    # parser has to define them.
    parser.add_argument(
        '--model-dir', type=str, default=None,
        help='directory of the trained model to load (required for --mode eval)',
    )
    # NOTE(review): None is forwarded as epoch=None to trainer.load_model;
    # confirm how load_model interprets it.
    parser.add_argument(
        '--load-epoch', type=int, default=None,
        help='epoch of the checkpoint to load in eval mode',
    )
    parser.add_argument('--debug', action='store_true', help='whether to do debugging')
    parser.add_argument(
        '--debug-port', type=int, default=6666,
        help='port debugpy listens on when --debug is set',
    )
    # Everything after the recognised options is collected here and passed to
    # base_cfg.cfg.merge_from_list() in main().
    parser.add_argument(
        'opts',
        default=None,
        nargs=argparse.REMAINDER,
        help='modify config options using the command-line',
    )
    return parser


if __name__ == '__main__':
    parser = build_parser()
    args = parser.parse_args()
    # Fail with a usage message instead of a traceback deep inside main().
    if args.mode == 'eval' and args.model_dir is None:
        parser.error('--model-dir is required when --mode is eval')
    if args.debug:
        # Imported here so debugpy is only needed when debugging is requested.
        import debugpy
        debugpy.listen(args.debug_port)
        print("Waiting for debugger attach (rank 0)...")
        debugpy.wait_for_client()
    main(args)