-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy path main.py
More file actions
37 lines (33 loc) · 1.72 KB
/
main.py
File metadata and controls
37 lines (33 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import os, sys
# Put this file's directory on sys.path so the modules imported below (``app``,
# ``config_parser``; presumably located next to this file) resolve regardless of the
# current working directory. ``abspath`` guards against ``__file__`` being relative,
# which would otherwise make the appended entry cwd-dependent.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))


def _lookup_trainer(trainer_type, namespace):
    """Return the trainer bound to the name *trainer_type* in *namespace*.

    Args:
        trainer_type (str): Name of the requested trainer.
        namespace (dict): Mapping of names to objects to search, e.g. ``globals()``.

    Returns:
        The callable bound to *trainer_type*.

    Raises:
        ValueError: If *trainer_type* is absent from *namespace* or is not callable.
    """
    trainer_cls = namespace.get(trainer_type)
    if trainer_cls is None or not callable(trainer_cls):
        raise ValueError(f"Unknown trainer type {trainer_type!r}; expected a trainer class "
                         f"exported by wisp.trainers.")
    return trainer_cls


if __name__ == "__main__":
    import app.app_utils
    # The star-import is deliberate: it places the trainer classes into globals() so the
    # requested trainer type can be resolved by name below.
    from wisp.trainers import *
    from config_parser import parse_options, argparse_to_str, get_modules_from_config, \
        get_optimizer_from_config
    from wisp.framework import WispState

    # Usual boilerplate
    parser = parse_options(return_parser=True)
    app.app_utils.add_log_level_flag(parser)
    app_group = parser.add_argument_group('app')
    # Add custom args to app_group if needed for app
    args, args_str = argparse_to_str(parser)
    app.app_utils.default_log_setup(args.log_level)

    # Resolve the trainer before building the pipeline/dataset, so a bad name fails fast
    # and is reported as a usage error instead of a bare KeyError.
    try:
        trainer_cls = _lookup_trainer(args.trainer_type, globals())
    except ValueError as err:
        parser.error(str(err))

    pipeline, train_dataset, device = get_modules_from_config(args)
    optim_cls, optim_params = get_optimizer_from_config(args)
    trainer = trainer_cls(pipeline, train_dataset, args.epochs, args.batch_size,
                          optim_cls, args.lr, args.weight_decay,
                          args.grid_lr_weight, optim_params, args.log_dir, device,
                          exp_name=args.exp_name, info=args_str, extra_args=vars(args),
                          render_every=args.render_every, save_every=args.save_every)

    # Either run validation only, or the full training loop.
    if args.valid_only:
        trainer.validate()
    else:
        trainer.train()