33
44# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/pqn/#pqnpy
55import os
6+ import pathlib
67import random
78import time
89from dataclasses import dataclass
2122
2223@dataclass
2324class Args :
25+ onnx_export_path : str = None
26+ """If set, will export onnx to this path after training is done"""
2427 env_path : str = None
2528 """Path to the Godot exported environment"""
2629 n_parallel : int = 1
@@ -258,9 +261,41 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
258261
259262 writer .add_scalar ("losses/td_loss" , loss , global_step )
260263 writer .add_scalar ("losses/q_values" , old_val .mean ().item (), global_step )
261- print ("SPS:" , int (global_step / (time .time () - start_time )))
262- print ("epsilon:" , epsilon )
264+ print (f"SPS: { int (global_step / (time .time () - start_time ))} , Epsilon: { epsilon } " )
263265 writer .add_scalar ("charts/SPS" , int (global_step / (time .time () - start_time )), global_step )
264266
265267 envs .close ()
266268 writer .close ()
269+
270+ if args .onnx_export_path is not None :
271+ path_onnx = pathlib .Path (args .onnx_export_path ).with_suffix (".onnx" )
272+ print ("Exporting onnx to: " + os .path .abspath (path_onnx ))
273+
274+ q_network .eval ().to ("cpu" )
275+
276+ class OnnxPolicy (torch .nn .Module ):
277+ def __init__ (self , network ):
278+ super ().__init__ ()
279+ self .network = network
280+
281+ def forward (self , onnx_obs , state_ins ):
282+ network_output = self .network (onnx_obs )
283+ return network_output , state_ins
284+
285+ onnx_policy = OnnxPolicy (q_network .network )
286+ dummy_input = torch .unsqueeze (torch .tensor (envs .single_observation_space .sample ()), 0 )
287+
288+ torch .onnx .export (
289+ onnx_policy ,
290+ args = (dummy_input , torch .zeros (1 ).float ()),
291+ f = str (path_onnx ),
292+ opset_version = 15 ,
293+ input_names = ["obs" , "state_ins" ],
294+ output_names = ["output" , "state_outs" ],
295+ dynamic_axes = {
296+ "obs" : {0 : "batch_size" },
297+ "state_ins" : {0 : "batch_size" }, # variable length axes
298+ "output" : {0 : "batch_size" },
299+ "state_outs" : {0 : "batch_size" },
300+ },
301+ )
0 commit comments