Skip to content

Commit cc1f2d1

Browse files
authored
Merge pull request #231 from Nijco/main
onnx export: properly handle obs_keys
2 parents 00502fd + 7ae1aaa commit cc1f2d1

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

godot_rl/wrappers/onnx/stable_baselines_export.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ def export_model_as_onnx(model, onnx_model_path: str, use_obs_array: bool = Fals
8181
args=(dummy_input, torch.zeros(1).float()),
8282
f=onnx_model_path,
8383
opset_version=17,
84-
input_names=["obs", "state_ins"],
84+
input_names=obs_keys + ["state_ins"],
8585
output_names=["output", "state_outs"],
8686
dynamic_axes={
87-
"obs": {0: "batch_size"},
87+
**{sub_obs: {0: "batch_size"} for sub_obs in obs_keys},
8888
"state_ins": {0: "batch_size"}, # variable length axes
8989
"output": {0: "batch_size"},
9090
"state_outs": {0: "batch_size"},
@@ -119,16 +119,17 @@ def verify_onnx_export(ppo: PPO, onnx_model_path: str, num_tests=10, use_obs_arr
119119
if use_obs_array:
120120
obs = np.expand_dims(ppo.observation_space.sample(), axis=0)
121121
obs2 = torch.tensor(obs)
122+
obs = {"obs": obs}
122123
else:
123-
obs = dict(ppo.observation_space.sample())
124+
obs_dict = dict(ppo.observation_space.sample())
124125
obs2 = {}
125-
for k, v in obs.items():
126+
for k, v in obs_dict.items():
126127
obs2[k] = torch.from_numpy(v).unsqueeze(0)
127-
obs = [v for v in obs.values()]
128+
obs = {k: [v] for k, v in obs_dict.items()}
128129

129130
with torch.no_grad():
130131
action_sb3, _, _ = sb3_model(obs2, deterministic=True)
131132

132-
action_onnx, state_outs = ort_sess.run(None, {"obs": obs, "state_ins": np.array([0.0], dtype=np.float32)})
133+
action_onnx, state_outs = ort_sess.run(None, {**obs, "state_ins": np.array([0.0], dtype=np.float32)})
133134
assert np.allclose(action_sb3, action_onnx, atol=1e-5), "Mismatch in action output"
134135
assert np.allclose(state_outs, np.array([0.0]), atol=1e-5), "Mismatch in state_outs output"

0 commit comments

Comments
 (0)