
Commit cfaaf15

Merge pull request #161 from edbeeching/add_clean_rl_onnx_export
Adds onnx export to cleanrl example
2 parents 69495f4 + 034ce4f commit cfaaf15

File tree

3 files changed: +73 / -260 lines

- docs/ADV_CLEAN_RL.md
- examples/clean_rl_example.py
- examples/stable_baselines3_example.py

docs/ADV_CLEAN_RL.md

Lines changed: 29 additions & 258 deletions
@@ -23,268 +23,39 @@ pip install godot-rl[cleanrl]
 While the default options for cleanrl work reasonably well. You may be interested in changing the hyperparameters.
 We recommend taking the [cleanrl example](https://github.com/edbeeching/godot_rl_agents/blob/main/examples/clean_rl_example.py) and modifying to match your needs.

-```python
-    parser.add_argument("--gae-lambda", type=float, default=0.95,
-        help="the lambda for the general advantage estimation")
-    parser.add_argument("--num-minibatches", type=int, default=8,
-        help="the number of mini-batches")
-    parser.add_argument("--update-epochs", type=int, default=10,
-        help="the K epochs to update the policy")
-    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
-        help="Toggles advantages normalization")
-    parser.add_argument("--clip-coef", type=float, default=0.2,
-        help="the surrogate clipping coefficient")
-    parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
-        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
-    parser.add_argument("--ent-coef", type=float, default=0.0001,
-        help="coefficient of the entropy")
-    parser.add_argument("--vf-coef", type=float, default=0.5,
-        help="coefficient of the value function")
-    parser.add_argument("--max-grad-norm", type=float, default=0.5,
-        help="the maximum norm for the gradient clipping")
-    parser.add_argument("--target-kl", type=float, default=None,
-        help="the target KL divergence threshold")
-    args = parser.parse_args()
+## CleanRL Example script usage:
+To use the example script, first move to the location where the downloaded script is in the console/terminal, and then try some of the example use cases below:

-    # fmt: on
-    return args
-
-
-def make_env(env_path, speedup):
-    def thunk():
-        env = CleanRLGodotEnv(env_path=env_path, show_window=True, speedup=speedup)
-        return env
-    return thunk
-
-
-def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
-    torch.nn.init.orthogonal_(layer.weight, std)
-    torch.nn.init.constant_(layer.bias, bias_const)
-    return layer
-
-
-class Agent(nn.Module):
-    def __init__(self, envs):
-        super().__init__()
-        self.critic = nn.Sequential(
-            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
-            nn.Tanh(),
-            layer_init(nn.Linear(64, 64)),
-            nn.Tanh(),
-            layer_init(nn.Linear(64, 1), std=1.0),
-        )
-        self.actor_mean = nn.Sequential(
-            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
-            nn.Tanh(),
-            layer_init(nn.Linear(64, 64)),
-            nn.Tanh(),
-            layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),
-        )
-        self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))
-
-    def get_value(self, x):
-        return self.critic(x)
-
-    def get_action_and_value(self, x, action=None):
-        action_mean = self.actor_mean(x)
-        action_logstd = self.actor_logstd.expand_as(action_mean)
-        action_std = torch.exp(action_logstd)
-        probs = Normal(action_mean, action_std)
-        if action is None:
-            action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
-
-
-if __name__ == "__main__":
-    args = parse_args()
-    run_name = f"{args.env_path}__{args.exp_name}__{args.seed}__{int(time.time())}"
-    if args.track:
-        import wandb
-
-        wandb.init(
-            project=args.wandb_project_name,
-            entity=args.wandb_entity,
-            sync_tensorboard=True,
-            config=vars(args),
-            name=run_name,
-            # monitor_gym=True, no longer works for gymnasium
-            save_code=True,
-        )
-    writer = SummaryWriter(f"runs/{run_name}")
-    writer.add_text(
-        "hyperparameters",
-        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
-    )
-
-    # TRY NOT TO MODIFY: seeding
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    torch.backends.cudnn.deterministic = args.torch_deterministic
-
-    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
-
-    # env setup
-
-    envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=True, speedup=args.speedup, convert_action_space=True)  # Godot envs are already vectorized
-    #assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
-    args.num_envs = envs.num_envs
-    args.batch_size = int(args.num_envs * args.num_steps)
-    args.minibatch_size = int(args.batch_size // args.num_minibatches)
-    agent = Agent(envs).to(device)
-    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
-
-    # ALGO Logic: Storage setup
-    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
-    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
-    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
-    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
-    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
-    values = torch.zeros((args.num_steps, args.num_envs)).to(device)
-
-    # TRY NOT TO MODIFY: start the game
-    global_step = 0
-    start_time = time.time()
-    next_obs, _ = envs.reset(seed=args.seed)
-    next_obs = torch.Tensor(next_obs).to(device)
-    next_done = torch.zeros(args.num_envs).to(device)
-    num_updates = args.total_timesteps // args.batch_size
-    video_filenames = set()
-
-    # episode reward stats, modified as Godot RL does not return this information in info (yet)
-    episode_returns = deque(maxlen=20)
-    accum_rewards = np.zeros(args.num_envs)
-
-    for update in range(1, num_updates + 1):
-        # Annealing the rate if instructed to do so.
-        if args.anneal_lr:
-            frac = 1.0 - (update - 1.0) / num_updates
-            lrnow = frac * args.learning_rate
-            optimizer.param_groups[0]["lr"] = lrnow
-
-        for step in range(0, args.num_steps):
-            global_step += 1 * args.num_envs
-            obs[step] = next_obs
-            dones[step] = next_done
-
-            # ALGO LOGIC: action logic
-            with torch.no_grad():
-                action, logprob, _, value = agent.get_action_and_value(next_obs)
-                values[step] = value.flatten()
-            actions[step] = action
-            logprobs[step] = logprob
-
-            # TRY NOT TO MODIFY: execute the game and log data.
-            next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
-            done = np.logical_or(terminated, truncated)
-            rewards[step] = torch.tensor(reward).to(device).view(-1)
-            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
-
-            accum_rewards += np.array(reward)
-
-            for i, d in enumerate(done):
-                if d:
-                    episode_returns.append(accum_rewards[i])
-                    accum_rewards[i] = 0
-
-        # bootstrap value if not done
-        with torch.no_grad():
-            next_value = agent.get_value(next_obs).reshape(1, -1)
-            advantages = torch.zeros_like(rewards).to(device)
-            lastgaelam = 0
-            for t in reversed(range(args.num_steps)):
-                if t == args.num_steps - 1:
-                    nextnonterminal = 1.0 - next_done
-                    nextvalues = next_value
-                else:
-                    nextnonterminal = 1.0 - dones[t + 1]
-                    nextvalues = values[t + 1]
-                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
-                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
-            returns = advantages + values
-
-        # flatten the batch
-        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
-        b_logprobs = logprobs.reshape(-1)
-        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
-        b_advantages = advantages.reshape(-1)
-        b_returns = returns.reshape(-1)
-        b_values = values.reshape(-1)
-
-        # Optimizing the policy and value network
-        b_inds = np.arange(args.batch_size)
-        clipfracs = []
-        for epoch in range(args.update_epochs):
-            np.random.shuffle(b_inds)
-            for start in range(0, args.batch_size, args.minibatch_size):
-                end = start + args.minibatch_size
-                mb_inds = b_inds[start:end]
-
-                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
-                logratio = newlogprob - b_logprobs[mb_inds]
-                ratio = logratio.exp()
-
-                with torch.no_grad():
-                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
-                    old_approx_kl = (-logratio).mean()
-                    approx_kl = ((ratio - 1) - logratio).mean()
-                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]
-
-                mb_advantages = b_advantages[mb_inds]
-                if args.norm_adv:
-                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
-
-                # Policy loss
-                pg_loss1 = -mb_advantages * ratio
-                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
-                pg_loss = torch.max(pg_loss1, pg_loss2).mean()
-
-                # Value loss
-                newvalue = newvalue.view(-1)
-                if args.clip_vloss:
-                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
-                    v_clipped = b_values[mb_inds] + torch.clamp(
-                        newvalue - b_values[mb_inds],
-                        -args.clip_coef,
-                        args.clip_coef,
-                    )
-                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
-                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
-                    v_loss = 0.5 * v_loss_max.mean()
-                else:
-                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()
-
-                entropy_loss = entropy.mean()
-                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
+### Train a model in editor:
+```bash
+python clean_rl_example.py
+```

-                optimizer.zero_grad()
-                loss.backward()
-                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
-                optimizer.step()
+### Train a model using an exported environment:
+```bash
+python clean_rl_example.py --env_path=path_to_executable
+```
+Note that the exported environment will not be rendered in order to accelerate training.
+If you want to display it, add the `--viz` argument.

-                if args.target_kl is not None:
-                    if approx_kl > args.target_kl:
-                        break
+### Train an exported environment using 4 environment processes:
+```bash
+python clean_rl_example.py --env_path=path_to_executable --n_parallel=4
+```

-        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
-        var_y = np.var(y_true)
-        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
+### Train an exported environment using 8 times speedup:
+```bash
+python clean_rl_example.py --env_path=path_to_executable --speedup=8
+```

-        # TRY NOT TO MODIFY: record rewards for plotting purposes
-        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
-        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
-        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
-        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
-        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
-        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
-        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
-        writer.add_scalar("losses/explained_variance", explained_var, global_step)
-        if len(episode_returns) > 0:
-            print("SPS:", int(global_step / (time.time() - start_time)), "Returns:", np.mean(np.array(episode_returns)))
-        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
-        writer.add_scalar("charts/episodic_return", np.mean(np.array(episode_returns)), global_step)
+### Set an experiment directory and name:
+```bash
+python clean_rl_example.py --experiment_dir="experiments" --experiment_name="experiment1"
+```

-    envs.close()
-    writer.close()
+### Train a model for 100_000 steps then export the model to onnx (can be used for inference in Godot, including in exported games - tested on only some platforms for now):
+```bash
+python clean_rl_example.py --total-timesteps=100_000 --onnx_export_path=model.onnx
+```

-```
+There are many other command line arguments defined in the [cleanrl example](https://github.com/edbeeching/godot_rl_agents/blob/main/examples/clean_rl_example.py) file.
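
The documented flags can be combined in a single command. As a rough sketch only (the executable path and values below are placeholders, not values from this commit):

```bash
# Hypothetical combined invocation of the options documented above; adjust for your project.
python clean_rl_example.py --env_path=path_to_executable --n_parallel=4 --speedup=8 \
    --experiment_dir="experiments" --experiment_name="experiment1" \
    --total-timesteps=1_000_000 --onnx_export_path=model.onnx
```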

examples/clean_rl_example.py

Lines changed: 43 additions & 1 deletion
@@ -1,6 +1,7 @@
 # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy
 import argparse
 import os
+import pathlib
 import random
 import time
 from distutils.util import strtobool
@@ -39,6 +40,12 @@ def parse_args():
         help="the entity (team) of wandb's project")
     parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         help="whether to capture videos of the agent performances (check out `videos` folder)")
+    parser.add_argument(
+        "--onnx_export_path",
+        default=None,
+        type=str,
+        help="If included, will export onnx file after training to the path specified."
+    )

     # Algorithm specific arguments
     parser.add_argument("--env_path", type=str, default=None,
@@ -160,7 +167,8 @@ def get_action_and_value(self, x, action=None):

     # env setup

-    envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed, n_parallel=args.n_parallel)
+    envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed,
+                                 n_parallel=args.n_parallel)
     args.num_envs = envs.num_envs
     args.batch_size = int(args.num_envs * args.num_steps)
     args.minibatch_size = int(args.batch_size // args.num_minibatches)
@@ -319,3 +327,37 @@ def get_action_and_value(self, x, action=None):

     envs.close()
     writer.close()
+
+    if args.onnx_export_path is not None:
+        path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
+        print("Exporting onnx to: " + os.path.abspath(path_onnx))
+
+        agent.eval().to("cpu")
+
+
+        class OnnxPolicy(torch.nn.Module):
+            def __init__(self, actor_mean):
+                super().__init__()
+                self.actor_mean = actor_mean
+
+            def forward(self, obs, state_ins):
+                action_mean = self.actor_mean(obs)
+                return action_mean, state_ins
+
+
+        onnx_policy = OnnxPolicy(agent.actor_mean)
+        dummy_input = torch.unsqueeze(torch.tensor(envs.single_observation_space.sample()), 0)
+
+        torch.onnx.export(
+            onnx_policy,
+            args=(dummy_input, torch.zeros(1).float()),
+            f=str(path_onnx),
+            opset_version=15,
+            input_names=["obs", "state_ins"],
+            output_names=["output", "state_outs"],
+            dynamic_axes={'obs': {0: 'batch_size'},
+                          'state_ins': {0: 'batch_size'},  # variable length axes
+                          'output': {0: 'batch_size'},
+                          'state_outs': {0: 'batch_size'}}
+
+        )
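
Once a model has been exported, the resulting file can be sanity-checked outside Godot. The sketch below is an illustration rather than part of this commit: it assumes the `onnxruntime` package is installed, reuses the `model.onnx` name from the docs example, and treats the observation size (`obs_dim = 16`) as a placeholder for whatever flat float32 observation vector your environment produces.

```python
# Minimal sketch: load the exported policy with onnxruntime and run one dummy observation.
# The file name and obs_dim are placeholders, not values taken from this commit.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")

obs_dim = 16  # placeholder: replace with your environment's observation size
dummy_obs = np.zeros((1, obs_dim), dtype=np.float32)
state_ins = np.zeros((1,), dtype=np.float32)  # mirrors the dummy state input used at export time

# Input/output names match the torch.onnx.export call above.
action_mean, _ = session.run(["output", "state_outs"], {"obs": dummy_obs, "state_ins": state_ins})
print("action mean shape:", action_mean.shape)
```

In the exported graph `state_ins` is passed straight through to `state_outs`, so any correctly shaped float32 array works for this check.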

examples/stable_baselines3_example.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@
     "--onnx_export_path",
     default=None,
     type=str,
-    help="The Godot binary to use, do not include for in editor training",
+    help="If included, will export onnx file after training to the path specified.",
 )
 parser.add_argument(
     "--timesteps",
