diff --git a/addons/godot_rl_agents/onnx/csharp/ONNXInference.cs b/addons/godot_rl_agents/onnx/csharp/ONNXInference.cs index 6dcfa18..ec845ed 100644 --- a/addons/godot_rl_agents/onnx/csharp/ONNXInference.cs +++ b/addons/godot_rl_agents/onnx/csharp/ONNXInference.cs @@ -19,48 +19,54 @@ public partial class ONNXInference : GodotObject private SessionOptions SessionOpt; - /// - /// init function - /// - /// - /// - /// Returns the output size of the model - public int Initialize(string Path, int BatchSize) + /// + /// init function + /// + /// + /// + /// Returns the output size of the model + public int Initialize(string Path, int BatchSize) { modelPath = Path; batchSize = BatchSize; - SessionOpt = SessionConfigurator.MakeConfiguredSessionOptions(); - session = LoadModel(modelPath); - return session.OutputMetadata["output"].Dimensions[1]; - } + SessionOpt = SessionConfigurator.MakeConfiguredSessionOptions(); + session = LoadModel(modelPath); + return session.OutputMetadata["output"].Dimensions[1]; + } /// - public Godot.Collections.Dictionary> RunInference(Godot.Collections.Array obs, int state_ins) + public Godot.Collections.Dictionary> RunInference(Godot.Collections.Dictionary> obs, int state_ins) { //Current model: Any (Godot Rl Agents) - //Expects a tensor of shape [batch_size, input_size] type float named obs and a tensor of shape [batch_size] type float named state_ins + //Expects a tensor of shape [batch_size, input_size] type float for any output of the agents observation dictionary and a tensor of shape [batch_size] type float named state_ins - //Fill the input tensors - // create span from inputSize - var span = new float[obs.Count]; //There's probably a better way to do this - for (int i = 0; i < obs.Count; i++) + var modelInputsList = new List { - span[i] = obs[i]; + NamedOnnxValue.CreateFromTensor("state_ins", new DenseTensor(new float[] { state_ins }, new int[] { batchSize })) + }; + foreach (var key in obs.Keys) + { + var subObs = obs[key]; + // Fill the input tensors for each key of the observation + // create span of observation from specific inputSize + var obsData = new float[subObs.Count]; //There's probably a better way to do this + for (int i = 0; i < subObs.Count; i++) + { + obsData[i] = subObs[i]; + } + modelInputsList.Add( + NamedOnnxValue.CreateFromTensor(key, new DenseTensor(obsData, new int[] { batchSize, subObs.Count })) + ); } - IReadOnlyCollection inputs = new List - { - NamedOnnxValue.CreateFromTensor("obs", new DenseTensor(span, new int[] { batchSize, obs.Count })), - NamedOnnxValue.CreateFromTensor("state_ins", new DenseTensor(new float[] { state_ins }, new int[] { batchSize })) - }; IReadOnlyCollection outputNames = new List { "output", "state_outs" }; //ONNX is sensible to these names, as well as the input names - IDisposableReadOnlyCollection results; + IDisposableReadOnlyCollection results; //We do not use "using" here so we get a better exception explaination later try { - results = session.Run(inputs, outputNames); + results = session.Run(modelInputsList, outputNames); } catch (OnnxRuntimeException e) { diff --git a/addons/godot_rl_agents/onnx/wrapper/ONNX_wrapper.gd b/addons/godot_rl_agents/onnx/wrapper/ONNX_wrapper.gd index e27f2c3..7d29a03 100644 --- a/addons/godot_rl_agents/onnx/wrapper/ONNX_wrapper.gd +++ b/addons/godot_rl_agents/onnx/wrapper/ONNX_wrapper.gd @@ -21,9 +21,9 @@ func _init(model_path, batch_size): action_output_size = inferencer.Initialize(model_path, batch_size) # This function is the one that will be called from the game, -# requires the observation as an array and the state_ins as an int -# returns an Array containing the action the model takes. -func run_inference(obs: Array, state_ins: int) -> Dictionary: +# requires the observations as an Dictionary and the state_ins as an int +# returns a Dictionary containing the action the model takes. +func run_inference(obs: Dictionary, state_ins: int) -> Dictionary: if inferencer == null: printerr("Inferencer not initialized") return {} diff --git a/addons/godot_rl_agents/sync.gd b/addons/godot_rl_agents/sync.gd index f47decb..20909e9 100644 --- a/addons/godot_rl_agents/sync.gd +++ b/addons/godot_rl_agents/sync.gd @@ -230,7 +230,9 @@ func _inference_process(): for agent_id in range(0, agents_inference.size()): var model: ONNXModel = agents_inference[agent_id].onnx_model - var action = model.run_inference(obs[agent_id]["obs"], 1.0) + var action = model.run_inference( + obs[agent_id], 1.0 + ) var action_dict = _extract_action_dict( action["output"], _action_space_inference[agent_id], model.action_means_only )