Skip to content

Commit 912169d

Browse files
committed
Adds PQN onnx export
1 parent 40d3770 commit 912169d

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

examples/clean_rl_pqn_example.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/pqn/#pqnpy
55
import os
6+
import pathlib
67
import random
78
import time
89
from dataclasses import dataclass
@@ -21,6 +22,8 @@
2122

2223
@dataclass
2324
class Args:
25+
onnx_export_path: str = None
26+
"""If set, will export onnx to this path after training is done"""
2427
env_path: str = None
2528
"""Path to the Godot exported environment"""
2629
n_parallel: int = 1
@@ -258,9 +261,41 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
258261

259262
writer.add_scalar("losses/td_loss", loss, global_step)
260263
writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
261-
print("SPS:", int(global_step / (time.time() - start_time)))
262-
print("epsilon:", epsilon)
264+
print(f"SPS: {int(global_step / (time.time() - start_time))}, Epsilon: {epsilon}")
263265
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
264266

265267
envs.close()
266268
writer.close()
269+
270+
if args.onnx_export_path is not None:
271+
path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
272+
print("Exporting onnx to: " + os.path.abspath(path_onnx))
273+
274+
q_network.eval().to("cpu")
275+
276+
class OnnxPolicy(torch.nn.Module):
277+
def __init__(self, network):
278+
super().__init__()
279+
self.network = network
280+
281+
def forward(self, onnx_obs, state_ins):
282+
network_output = self.network(onnx_obs)
283+
return network_output, state_ins
284+
285+
onnx_policy = OnnxPolicy(q_network.network)
286+
dummy_input = torch.unsqueeze(torch.tensor(envs.single_observation_space.sample()), 0)
287+
288+
torch.onnx.export(
289+
onnx_policy,
290+
args=(dummy_input, torch.zeros(1).float()),
291+
f=str(path_onnx),
292+
opset_version=15,
293+
input_names=["obs", "state_ins"],
294+
output_names=["output", "state_outs"],
295+
dynamic_axes={
296+
"obs": {0: "batch_size"},
297+
"state_ins": {0: "batch_size"}, # variable length axes
298+
"output": {0: "batch_size"},
299+
"state_outs": {0: "batch_size"},
300+
},
301+
)

0 commit comments

Comments
 (0)