@@ -167,7 +167,8 @@ def get_action_and_value(self, x, action=None):

     # env setup

-envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed, n_parallel=args.n_parallel)
+envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed,
+                             n_parallel=args.n_parallel)
 args.num_envs = envs.num_envs
 args.batch_size = int(args.num_envs * args.num_steps)
 args.minibatch_size = int(args.batch_size // args.num_minibatches)
@@ -333,6 +334,7 @@ def get_action_and_value(self, x, action=None):

 agent.eval().to("cpu")

+
 class OnnxPolicy(torch.nn.Module):
     def __init__(self, actor_mean):
         super().__init__()
@@ -342,6 +344,7 @@ def forward(self, obs, state_ins):
         action_mean = self.actor_mean(obs)
         return action_mean, state_ins

+
 onnx_policy = OnnxPolicy(agent.actor_mean)
 dummy_input = torch.unsqueeze(torch.tensor(envs.single_observation_space.sample()), 0)

@@ -352,9 +355,9 @@ def forward(self, obs, state_ins):
                   opset_version=15,
                   input_names=["obs", "state_ins"],
                   output_names=["output", "state_outs"],
-                  dynamic_axes={'obs': {0: 'batch_size'},
-                      'state_ins': {0: 'batch_size'},  # variable length axes
-                      'output': {0: 'batch_size'},
-                      'state_outs': {0: 'batch_size'}}
+                  dynamic_axes={'obs': {0: 'batch_size'},
+                                'state_ins': {0: 'batch_size'},  # variable length axes
+                                'output': {0: 'batch_size'},
+                                'state_outs': {0: 'batch_size'}}

-)
+                  )
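
Not part of this change, but as a rough sanity check of the export above, the ONNX file can be loaded back with onnxruntime. The file name, observation width, and state_ins shape below are placeholders and should match whatever the script actually exports:

import numpy as np
import onnxruntime as ort

# Placeholder path; use the export path the script actually writes to.
session = ort.InferenceSession("model.onnx")

batch = 4      # any batch size should work, since axis 0 of "obs" and "state_ins" is dynamic
obs_dim = 16   # placeholder; should match envs.single_observation_space.shape[0]
obs = np.random.randn(batch, obs_dim).astype(np.float32)
state_ins = np.zeros((batch,), dtype=np.float32)  # assumed dummy state matching the rank traced at export time

# Output names follow the export call above ("output", "state_outs").
action_mean, state_outs = session.run(["output", "state_outs"], {"obs": obs, "state_ins": state_ins})
print(action_mean.shape)  # expected: (batch, action_dim)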