@@ -81,10 +81,10 @@ def export_model_as_onnx(model, onnx_model_path: str, use_obs_array: bool = Fals
8181 args = (dummy_input , torch .zeros (1 ).float ()),
8282 f = onnx_model_path ,
8383 opset_version = 17 ,
84- input_names = [ "obs" , "state_ins" ],
84+ input_names = obs_keys + [ "state_ins" ],
8585 output_names = ["output" , "state_outs" ],
8686 dynamic_axes = {
87- "obs" : {0 : "batch_size" },
87+ ** { sub_obs : {0 : "batch_size" } for sub_obs in obs_keys },
8888 "state_ins" : {0 : "batch_size" }, # variable length axes
8989 "output" : {0 : "batch_size" },
9090 "state_outs" : {0 : "batch_size" },
@@ -119,16 +119,17 @@ def verify_onnx_export(ppo: PPO, onnx_model_path: str, num_tests=10, use_obs_arr
119119 if use_obs_array :
120120 obs = np .expand_dims (ppo .observation_space .sample (), axis = 0 )
121121 obs2 = torch .tensor (obs )
122+ obs = {"obs" : obs }
122123 else :
123- obs = dict (ppo .observation_space .sample ())
124+ obs_dict = dict (ppo .observation_space .sample ())
124125 obs2 = {}
125- for k , v in obs .items ():
126+ for k , v in obs_dict .items ():
126127 obs2 [k ] = torch .from_numpy (v ).unsqueeze (0 )
127- obs = [ v for v in obs . values ()]
128+ obs = { k : [ v ] for k , v in obs_dict . items ()}
128129
129130 with torch .no_grad ():
130131 action_sb3 , _ , _ = sb3_model (obs2 , deterministic = True )
131132
132- action_onnx , state_outs = ort_sess .run (None , {"obs" : obs , "state_ins" : np .array ([0.0 ], dtype = np .float32 )})
133+ action_onnx , state_outs = ort_sess .run (None , {** obs , "state_ins" : np .array ([0.0 ], dtype = np .float32 )})
133134 assert np .allclose (action_sb3 , action_onnx , atol = 1e-5 ), "Mismatch in action output"
134135 assert np .allclose (state_outs , np .array ([0.0 ]), atol = 1e-5 ), "Mismatch in state_outs output"
0 commit comments