-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patheval.py
More file actions
48 lines (38 loc) · 1.33 KB
/
eval.py
File metadata and controls
48 lines (38 loc) · 1.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
from argparse import Namespace
from typing import Callable
import sys
import gym
import numpy as np
import torch
from tensorboardX import SummaryWriter
import wandb
from core.ma_gym.eval_loop import gym_loop
from utils.logger import Logger
from utils.parser import EvalOptions
from utils.utils import fix_random
def _select_device(args):
    """Return the ``torch.device`` to run on, chosen from ``args.gpu_ids``.

    A negative first id, or CUDA being unavailable, selects the CPU. A single
    non-negative id selects that CUDA device directly. Several ids restrict
    ``CUDA_VISIBLE_DEVICES`` to them, and ``"cuda"`` then maps to the first.

    Args:
        args: parsed options; reads ``gpu_ids`` (sequence of ints) and
            ``isTrain`` (only used for the printed mode label).

    Returns:
        torch.device: the selected device.
    """
    mode = "Training" if args.isTrain else "Executing"
    use_gpu = args.gpu_ids[0] >= 0
    if use_gpu and len(args.gpu_ids) > 1:
        # Set visibility *before* torch.cuda.is_available(): CUDA reads this
        # variable when it initialises, so assigning it afterwards may be ignored.
        # Canonical "0,1" form instead of the list-repr slice "0, 1".
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in args.gpu_ids)
    if not use_gpu or not torch.cuda.is_available():
        print("%s on CPU" % mode)
        return torch.device("cpu")
    print("%s on GPU" % mode)
    if len(args.gpu_ids) > 1:
        return torch.device("cuda")
    # Honour a single requested id instead of silently falling back to cuda:0.
    return torch.device("cuda", args.gpu_ids[0])


def main(args):
    """Run the evaluation loop configured by ``args``.

    Optionally starts a Weights & Biases run and a TensorBoard writer, builds
    the validation ``Logger``, selects the compute device and delegates to
    ``gym_loop``.

    Args:
        args: parsed options from ``EvalOptions().parse()``. Fields read here
            are ``wandb``, ``tensorboard``, ``val_episodes``, ``print_freq``,
            ``gpu_ids`` and ``isTrain``. ``gym_loop`` may read others.
    """
    # Init loggers
    if args.wandb:
        wandb.init(group="CoMix", project="TrafficAD", entity="johnminelli")
    tb_writer = SummaryWriter() if args.tensorboard else None
    try:
        logger = Logger(
            valid=True,
            episodes=args.val_episodes,
            batch_size=1,
            terminal_print_freq=args.print_freq,
            tensorboard=tb_writer,
            wand=args.wandb,
        )
        # Seeding is currently disabled; re-enable for reproducible runs:
        # fix_random(args.seed)
        device = _select_device(args)
        gym_loop(args, device, logger)
    finally:
        # Flush and release the TensorBoard writer even if the loop raises.
        if tb_writer is not None:
            tb_writer.close()
if __name__ == '__main__':
    # Script entry point: parse the evaluation CLI options and run main with them.
    main(EvalOptions().parse())