# baselines.py
import copy

import torch as th

from environment import MasterEnv
from functions import initialize_dqn, initialize_ppo, learning


def vanilla(env, model_name, seed, args, algorithm="PPO"):
    """Train a flat baseline agent on ``env`` and save the resulting model.

    ``args`` supplies the hyperparameters; ``algorithm`` selects "DQN" or
    the default "PPO".
    """
    if algorithm == "DQN":
        policy_kwargs = dict(activation_fn=th.nn.ReLU, net_arch=args["network"])
        model = initialize_dqn(
            env,
            num_worker=args["num_worker"],
            seed=seed,
            policy_kwargs=policy_kwargs,
            learning_rate=args["learning_rate"],
            gamma=args["gamma"],
            learning_starts=args["learning_starts"],
            batch_size=args["batch_size"],
            gradient_steps=args["gradient_steps"],
            target_update_interval=args["target_update_interval"],
            tau=args["tau"],
        )
    else:
        policy_kwargs = dict(
            activation_fn=th.nn.ReLU, net_arch=dict(pi=args["network"], vf=args["vf"])
        )
        model = initialize_ppo(
            env,
            num_worker=args["num_worker"],
            seed=seed,
            policy_kwargs=policy_kwargs,
            clip_range=args["clip_range"],
            learning_rate=args["learning_rate"],
            gamma=args["gamma"],
            ent_coef=args["ent_coef"],
            gae_lambda=args["gae_lambda"],
            rollout_length=args["rollout_length"],
            n_epochs=args["n_epochs"],
        )
    # Evaluate on an independent copy of the environment during training.
    eval_env = copy.deepcopy(env)
    model = learning(
        model,
        eval_env,
        num_worker=args["num_worker"],
        log_path=args["log_path"],
        model_name=model_name,
        time_steps=args["num_iterations"],
        eval_freq=args["eval_freq"],
        n_eval_episodes=args["n_eval_episodes"],
    )
    # Note: log_path is expected to end with a path separator.
    model.save(args["log_path"] + model_name + "_MODEL")


def training_tasks_learning(training_envs, seed, training_tasks_args):
    """Train one vanilla agent per training task, saving each under its own name."""
    print(training_tasks_args)
    for i, env in enumerate(training_envs):
        print(f"Task {i + 1}")
        model_name = f"task{i + 1}_seed{seed}"
        vanilla(env, model_name, seed, training_tasks_args)


def train_agent(env, programs, option_sizes, model_name, seed, args, algorithm="PPO"):
    """Train a master agent whose action space is extended with option programs."""
    # Wrap the base environment so the master policy can invoke the options.
    master_env = MasterEnv(env=env, option_sizes=option_sizes, options=programs)
    if algorithm == "DQN":
        policy_kwargs = dict(activation_fn=th.nn.ReLU, net_arch=args["network"])
        model = initialize_dqn(
            master_env,
            num_worker=args["num_worker"],
            seed=seed,
            policy_kwargs=policy_kwargs,
            learning_rate=args["learning_rate"],
            gamma=args["gamma"],
            learning_starts=args["learning_starts"],
            batch_size=args["batch_size"],
            gradient_steps=args["gradient_steps"],
            target_update_interval=args["target_update_interval"],
            tau=args["tau"],
        )
    else:
        policy_kwargs = dict(
            activation_fn=th.nn.ReLU, net_arch=dict(pi=args["network"], vf=args["vf"])
        )
        model = initialize_ppo(
            master_env,
            num_worker=args["num_worker"],
            seed=seed,
            policy_kwargs=policy_kwargs,
            clip_range=args["clip_range"],
            learning_rate=args["learning_rate"],
            gamma=args["gamma"],
            ent_coef=args["ent_coef"],
            gae_lambda=args["gae_lambda"],
            rollout_length=args["rollout_length"],
            n_epochs=args["n_epochs"],
        )
    # Evaluate on an independent copy of the wrapped environment.
    eval_env = copy.deepcopy(master_env)
    model = learning(
        model,
        eval_env,
        num_worker=args["num_worker"],
        log_path=args["log_path"],
        model_name=model_name,
        time_steps=args["num_iterations"],
        eval_freq=args["eval_freq"],
        n_eval_episodes=args["n_eval_episodes"],
    )
    # Save the trained master model.
    model.save(args["log_path"] + model_name + "_MASTER")
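

# --- Usage sketch (illustrative, not part of the original file) ---------------
# A minimal example of driving `vanilla` with a PPO configuration. The
# environment factory `make_env` and the concrete hyperparameter values below
# are assumptions; only the dictionary keys are taken from how the functions
# above read `args`.
#
# if __name__ == "__main__":
#     env = make_env()  # hypothetical environment factory
#     ppo_args = {
#         "num_worker": 4,
#         "network": [64, 64],       # pi network architecture
#         "vf": [64, 64],            # value-function network architecture
#         "clip_range": 0.2,
#         "learning_rate": 3e-4,
#         "gamma": 0.99,
#         "ent_coef": 0.01,
#         "gae_lambda": 0.95,
#         "rollout_length": 2048,
#         "n_epochs": 10,
#         "num_iterations": 1_000_000,
#         "eval_freq": 10_000,
#         "n_eval_episodes": 5,
#         "log_path": "./logs/",     # must end with a path separator
#     }
#     vanilla(env, "vanilla_seed0", seed=0, args=ppo_args, algorithm="PPO")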