Skip to content

Commit 034ce4f

Browse files
committed
Reformat clean_rl_example.py
1 parent e023c0f commit 034ce4f

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

examples/clean_rl_example.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ def get_action_and_value(self, x, action=None):
167167

168168
# env setup
169169

170-
envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed, n_parallel=args.n_parallel)
170+
envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, seed=args.seed,
171+
n_parallel=args.n_parallel)
171172
args.num_envs = envs.num_envs
172173
args.batch_size = int(args.num_envs * args.num_steps)
173174
args.minibatch_size = int(args.batch_size // args.num_minibatches)
@@ -333,6 +334,7 @@ def get_action_and_value(self, x, action=None):
333334

334335
agent.eval().to("cpu")
335336

337+
336338
class OnnxPolicy(torch.nn.Module):
337339
def __init__(self, actor_mean):
338340
super().__init__()
@@ -342,6 +344,7 @@ def forward(self, obs, state_ins):
342344
action_mean = self.actor_mean(obs)
343345
return action_mean, state_ins
344346

347+
345348
onnx_policy = OnnxPolicy(agent.actor_mean)
346349
dummy_input = torch.unsqueeze(torch.tensor(envs.single_observation_space.sample()), 0)
347350

@@ -352,9 +355,9 @@ def forward(self, obs, state_ins):
352355
opset_version=15,
353356
input_names=["obs", "state_ins"],
354357
output_names=["output", "state_outs"],
355-
dynamic_axes={'obs' : {0 : 'batch_size'},
356-
'state_ins' : {0 : 'batch_size'}, # variable length axes
357-
'output' : {0 : 'batch_size'},
358-
'state_outs' : {0 : 'batch_size'}}
358+
dynamic_axes={'obs': {0: 'batch_size'},
359+
'state_ins': {0: 'batch_size'}, # variable length axes
360+
'output': {0: 'batch_size'},
361+
'state_outs': {0: 'batch_size'}}
359362

360-
)
363+
)

0 commit comments

Comments
 (0)