Skip to content

Commit 12ab0d4

Browse files
authored
Add Parallel Q Network (PQN) cleanrl example
[Ready to merge] Add Parallel Q Network (PQN) cleanrl example and single-discrete action support
2 parents f4b5392 + a0bfa10 commit 12ab0d4

File tree

4 files changed

+325
-4
lines changed

4 files changed

+325
-4
lines changed

examples/clean_rl_pqn_example.py

Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
# Original file taken from CleanRL https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/pqn.py
2+
# and adapted to work with Godot RL Agents envs
3+
4+
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/pqn/#pqnpy
5+
import os
6+
import pathlib
7+
import random
8+
import time
9+
from collections import deque
10+
from dataclasses import dataclass
11+
12+
import gymnasium as gym
13+
import numpy as np
14+
import torch
15+
import torch.nn as nn
16+
import torch.nn.functional as F
17+
import torch.optim as optim
18+
import tyro
19+
from torch.utils.tensorboard import SummaryWriter
20+
21+
from godot_rl.wrappers.clean_rl_wrapper import CleanRLGodotEnv
22+
23+
24+
@dataclass
class Args:
    """Command-line arguments for the PQN CleanRL example, parsed by `tyro.cli`.

    The bare string literals following each field are intentional: tyro reads
    these "attribute docstrings" and uses them as per-flag help text.
    """

    onnx_export_path: str = None
    """If set, will export onnx to this path after training is done"""
    env_path: str = None
    """Path to the Godot exported environment"""
    n_parallel: int = 1
    """How many instances of the environment executable to
    launch (requires --env_path to be set if > 1)."""
    viz: bool = False
    """Whether the exported Godot environment will be displayed during training"""
    speedup: int = 8
    """How much to speed up the environment"""
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "CartPole-v1"
    """the id of the environment"""
    total_timesteps: int = 1_000_000
    """total timesteps of the experiments"""
    learning_rate: float = 2.5e-4
    """the learning rate of the optimizer [note: automatically set]"""
    num_envs: int = 4
    """the number of parallel game environments"""
    num_steps: int = 128
    """the number of steps to run for each environment per update"""
    num_minibatches: int = 4
    """the number of mini-batches"""
    update_epochs: int = 4
    """the K epochs to update the policy"""
    anneal_lr: bool = True
    """Toggle learning rate annealing"""
    gamma: float = 0.99
    """the discount factor gamma"""
    start_e: float = 1
    """the starting epsilon for exploration"""
    end_e: float = 0.05
    """the ending epsilon for exploration"""
    exploration_fraction: float = 0.5
    """the fraction of `total_timesteps` it takes from start_e to end_e"""
    max_grad_norm: float = 10.0
    """the maximum norm for the gradient clipping"""
    q_lambda: float = 0.65
    """the lambda for Q(lambda)"""
83+
84+
85+
# def make_env(env_path):
86+
# def thunk():
87+
# env = CleanRLGodotEnv(env_path=env_path, show_window=True)
88+
# return env
89+
#
90+
# return thunk
91+
92+
93+
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialize a linear layer in place and return it.

    The weight matrix receives an orthogonal initialization scaled by
    ``std``; the bias vector is filled with the constant ``bias_const``.
    Returning the layer lets callers use this inline inside ``nn.Sequential``.
    """
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias, bias_const)
    return layer
97+
98+
99+
# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    """MLP Q-network: maps an observation to one Q-value per discrete action.

    LayerNorm after each hidden linear layer is the PQN ingredient that
    stabilizes training without a separate target network.
    """

    def __init__(self, envs):
        super().__init__()

        self.network = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 120)),
            nn.LayerNorm(120),
            nn.ReLU(),
            layer_init(nn.Linear(120, 84)),
            nn.LayerNorm(84),
            nn.ReLU(),
            # Bug fix: the original read the module-level global `env` here
            # (`env.single_action_space.n`), which only worked because the
            # training script happens to alias `envs = env = ...`. Use the
            # constructor argument so the class is self-contained.
            layer_init(nn.Linear(84, envs.single_action_space.n)),
        )

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions) for observations x."""
        return self.network(x)
116+
117+
118+
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Linearly interpolate epsilon from ``start_e`` toward ``end_e``.

    Decreases at a constant rate over ``duration`` steps and stays clamped
    at ``end_e`` once ``t`` reaches ``duration``.
    """
    fraction = t / duration
    interpolated = start_e + fraction * (end_e - start_e)
    return max(interpolated, end_e)
121+
122+
123+
if __name__ == "__main__":
    args = tyro.cli(Args)

    # env setup
    # NOTE(review): `envs` and `env` intentionally alias the same wrapper
    # object so module-level code can refer to either name.
    envs = env = CleanRLGodotEnv(
        env_path=args.env_path,
        show_window=args.viz,
        speedup=args.speedup,
        seed=args.seed,
        n_parallel=args.n_parallel,
    )
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    # Derived sizes: a batch is one rollout across all envs; the iteration
    # count is how many rollout/update cycles fit in the timestep budget.
    args.num_envs = envs.num_envs
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # agent setup
    q_network = QNetwork(envs).to(device)
    optimizer = optim.RAdam(q_network.parameters(), lr=args.learning_rate)

    # storage setup: rollout buffers of shape (num_steps, num_envs, ...)
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)

    # episode reward stats, modified as Godot RL does not return this information in info (yet)
    episode_returns = deque(maxlen=20)
    accum_rewards = np.zeros(args.num_envs)

    for iteration in range(1, args.num_iterations + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        # ---- rollout phase: collect num_steps transitions per env ----
        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # Epsilon-greedy action selection with linearly decayed epsilon.
            epsilon = linear_schedule(
                args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step
            )
            random_actions = torch.randint(0, envs.single_action_space.n, (args.num_envs,)).to(device)
            with torch.no_grad():
                q_values = q_network(next_obs)
                max_actions = torch.argmax(q_values, dim=1)
                # Store max-Q per env; used later as the bootstrap value in
                # the Q(lambda) recursion.
                values[step] = q_values[torch.arange(args.num_envs), max_actions].flatten()

            explore = torch.rand((args.num_envs,)).to(device) < epsilon
            action = torch.where(explore, random_actions, max_actions)
            actions[step] = action

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
            next_done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

            # Track per-env episodic returns manually (not provided in `infos`).
            accum_rewards += np.array(reward)

            for i, d in enumerate(next_done):
                if d:
                    episode_returns.append(accum_rewards[i])
                    accum_rewards[i] = 0

        # Compute Q(lambda) targets by backward recursion over the rollout.
        with torch.no_grad():
            returns = torch.zeros_like(rewards).to(device)
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    # Bootstrap the final step from the current network's max-Q
                    # on the post-rollout observation; zero it out on terminals.
                    next_value, _ = torch.max(q_network(next_obs), dim=-1)
                    nextnonterminal = 1.0 - next_done
                    returns[t] = rewards[t] + args.gamma * next_value * nextnonterminal
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    next_value = values[t + 1]
                    # Q(lambda): mix the recursive return with the stored
                    # one-step bootstrap value.
                    returns[t] = (
                        rewards[t]
                        + args.gamma
                        * (args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value)
                        * nextnonterminal
                    )

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_returns = returns.reshape(-1)

        # Optimizing the Q-network: epochs of shuffled minibatch regression.
        b_inds = np.arange(args.batch_size)
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                # MSE regression of Q(s, a) toward the Q(lambda) targets.
                old_val = q_network(b_obs[mb_inds]).gather(1, b_actions[mb_inds].unsqueeze(-1).long()).squeeze()
                loss = F.mse_loss(b_returns[mb_inds], old_val)

                # optimize the model
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(q_network.parameters(), args.max_grad_norm)
                optimizer.step()

        # Logging: values from the last minibatch of the last epoch.
        writer.add_scalar("losses/td_loss", loss, global_step)
        writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
        print(f"SPS: {int(global_step / (time.time() - start_time))}, Epsilon (rand action prob): {epsilon}")
        if len(episode_returns) > 0:
            print(
                "Returns:",
                np.mean(np.array(episode_returns)),
            )
            writer.add_scalar("charts/episodic_return", np.mean(np.array(episode_returns)), global_step)

    envs.close()
    writer.close()

    if args.onnx_export_path is not None:
        path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
        print("Exporting onnx to: " + os.path.abspath(path_onnx))

        q_network.eval().to("cpu")

        # Godot RL's onnx inference interface expects a recurrent-style
        # (obs, state_ins) -> (output, state_outs) signature, so wrap the
        # feed-forward network and pass the state through unchanged.
        class OnnxPolicy(torch.nn.Module):
            def __init__(self, network):
                super().__init__()
                self.network = network

            def forward(self, onnx_obs, state_ins):
                network_output = self.network(onnx_obs)
                return network_output, state_ins

        onnx_policy = OnnxPolicy(q_network.network)
        dummy_input = torch.unsqueeze(torch.tensor(envs.single_observation_space.sample()), 0)

        torch.onnx.export(
            onnx_policy,
            args=(dummy_input, torch.zeros(1).float()),
            f=str(path_onnx),
            opset_version=15,
            input_names=["obs", "state_ins"],
            output_names=["output", "state_outs"],
            dynamic_axes={
                "obs": {0: "batch_size"},
                "state_ins": {0: "batch_size"},  # variable length axes
                "output": {0: "batch_size"},
                "state_outs": {0: "batch_size"},
            },
        )

examples/stable_baselines3_hp_tuning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" Optuna example that optimizes the hyperparameters of
1+
"""Optuna example that optimizes the hyperparameters of
22
a reinforcement learning agent using PPO implementation from Stable-Baselines3
33
on a Gymnasium environment.
44

godot_rl/core/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,21 @@ class ActionSpaceProcessor:
3737
def __init__(self, action_space: gym.spaces.Tuple, convert) -> None:
3838
self._original_action_space = action_space
3939
self._convert = convert
40+
self._all_actions_discrete: bool = all(isinstance(space, gym.spaces.Discrete) for space in action_space.spaces)
41+
self._only_one_action_space: bool = len(action_space) == 1
42+
43+
# For DQN or other similar algorithms, we need a single discrete action space
44+
if self._only_one_action_space and self._all_actions_discrete:
45+
self.converted_action_space = action_space[0]
46+
return
4047

4148
space_size = 0
4249

4350
if convert:
4451
use_multi_discrete_spaces = False
4552
multi_discrete_spaces = np.array([])
4653
if isinstance(action_space, gym.spaces.Tuple):
47-
if all(isinstance(space, gym.spaces.Discrete) for space in action_space.spaces):
54+
if self._all_actions_discrete:
4855
use_multi_discrete_spaces = True
4956
for space in action_space.spaces:
5057
multi_discrete_spaces = np.append(multi_discrete_spaces, space.n)
@@ -84,6 +91,8 @@ def action_space(self):
8491
def to_original_dist(self, action):
8592
if not self._convert:
8693
return action
94+
elif self._only_one_action_space and self._all_actions_discrete:
95+
return [action]
8796

8897
original_action = []
8998
counter = 0

godot_rl/wrappers/onnx/stable_baselines_export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,10 @@ def export_model_as_onnx(model, onnx_model_path: str, use_obs_array: bool = Fals
9494
# We only verify with PPO currently due to different output shape with SAC
9595
# (this can be updated in the future)
9696
if isinstance(model, PPO):
97-
# If the space is MultiDiscrete, we skip verifying as action output will have an expected mismatch
97+
# If the space is Discrete/MultiDiscrete, we skip verifying as action output will have an expected mismatch
9898
# (the output from onnx will be the action logits for each discrete action,
9999
# while the output from sb3 will be a single int)
100-
if not isinstance(model.action_space, spaces.MultiDiscrete):
100+
if not isinstance(model.action_space, (spaces.Discrete, spaces.MultiDiscrete)):
101101
verify_onnx_export(model, onnx_model_path, use_obs_array=use_obs_array)
102102

103103

0 commit comments

Comments
 (0)