Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions godot_rl/wrappers/onnx/stable_baselines_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def export_model_as_onnx(model, onnx_model_path: str, use_obs_array: bool = Fals
args=(dummy_input, torch.zeros(1).float()),
f=onnx_model_path,
opset_version=17,
input_names=["obs", "state_ins"],
input_names=obs_keys + ["state_ins"],
output_names=["output", "state_outs"],
dynamic_axes={
"obs": {0: "batch_size"},
**{sub_obs: {0: "batch_size"} for sub_obs in obs_keys},
"state_ins": {0: "batch_size"}, # variable length axes
"output": {0: "batch_size"},
"state_outs": {0: "batch_size"},
Expand Down Expand Up @@ -119,16 +119,17 @@ def verify_onnx_export(ppo: PPO, onnx_model_path: str, num_tests=10, use_obs_arr
if use_obs_array:
obs = np.expand_dims(ppo.observation_space.sample(), axis=0)
obs2 = torch.tensor(obs)
obs = {"obs": obs}
else:
obs = dict(ppo.observation_space.sample())
obs_dict = dict(ppo.observation_space.sample())
obs2 = {}
for k, v in obs.items():
for k, v in obs_dict.items():
obs2[k] = torch.from_numpy(v).unsqueeze(0)
obs = [v for v in obs.values()]
obs = {k: [v] for k, v in obs_dict.items()}

with torch.no_grad():
action_sb3, _, _ = sb3_model(obs2, deterministic=True)

action_onnx, state_outs = ort_sess.run(None, {"obs": obs, "state_ins": np.array([0.0], dtype=np.float32)})
action_onnx, state_outs = ort_sess.run(None, {**obs, "state_ins": np.array([0.0], dtype=np.float32)})
assert np.allclose(action_sb3, action_onnx, atol=1e-5), "Mismatch in action output"
assert np.allclose(state_outs, np.array([0.0]), atol=1e-5), "Mismatch in state_outs output"