diff --git a/addons/godot_rl_agents/onnx/csharp/ONNXInference.cs b/addons/godot_rl_agents/onnx/csharp/ONNXInference.cs
index 6dcfa18..ec845ed 100644
--- a/addons/godot_rl_agents/onnx/csharp/ONNXInference.cs
+++ b/addons/godot_rl_agents/onnx/csharp/ONNXInference.cs
@@ -19,48 +19,54 @@ public partial class ONNXInference : GodotObject
private SessionOptions SessionOpt;
- ///
- /// init function
- ///
- ///
- ///
- /// Returns the output size of the model
- public int Initialize(string Path, int BatchSize)
+ ///
+ /// init function
+ ///
+ ///
+ ///
+ /// Returns the output size of the model
+ public int Initialize(string Path, int BatchSize)
{
modelPath = Path;
batchSize = BatchSize;
- SessionOpt = SessionConfigurator.MakeConfiguredSessionOptions();
- session = LoadModel(modelPath);
- return session.OutputMetadata["output"].Dimensions[1];
- }
+ SessionOpt = SessionConfigurator.MakeConfiguredSessionOptions();
+ session = LoadModel(modelPath);
+ return session.OutputMetadata["output"].Dimensions[1];
+ }
///
- public Godot.Collections.Dictionary> RunInference(Godot.Collections.Array obs, int state_ins)
+ public Godot.Collections.Dictionary> RunInference(Godot.Collections.Dictionary> obs, int state_ins)
{
//Current model: Any (Godot Rl Agents)
- //Expects a tensor of shape [batch_size, input_size] type float named obs and a tensor of shape [batch_size] type float named state_ins
+ //Expects a tensor of shape [batch_size, input_size] type float for any output of the agents observation dictionary and a tensor of shape [batch_size] type float named state_ins
- //Fill the input tensors
- // create span from inputSize
- var span = new float[obs.Count]; //There's probably a better way to do this
- for (int i = 0; i < obs.Count; i++)
+ var modelInputsList = new List
{
- span[i] = obs[i];
+ NamedOnnxValue.CreateFromTensor("state_ins", new DenseTensor(new float[] { state_ins }, new int[] { batchSize }))
+ };
+ foreach (var key in obs.Keys)
+ {
+ var subObs = obs[key];
+ // Fill the input tensors for each key of the observation
+ // create span of observation from specific inputSize
+ var obsData = new float[subObs.Count]; //There's probably a better way to do this
+ for (int i = 0; i < subObs.Count; i++)
+ {
+ obsData[i] = subObs[i];
+ }
+ modelInputsList.Add(
+ NamedOnnxValue.CreateFromTensor(key, new DenseTensor(obsData, new int[] { batchSize, subObs.Count }))
+ );
}
- IReadOnlyCollection inputs = new List
- {
- NamedOnnxValue.CreateFromTensor("obs", new DenseTensor(span, new int[] { batchSize, obs.Count })),
- NamedOnnxValue.CreateFromTensor("state_ins", new DenseTensor(new float[] { state_ins }, new int[] { batchSize }))
- };
IReadOnlyCollection outputNames = new List { "output", "state_outs" }; //ONNX is sensible to these names, as well as the input names
- IDisposableReadOnlyCollection results;
+ IDisposableReadOnlyCollection results;
//We do not use "using" here so we get a better exception explaination later
try
{
- results = session.Run(inputs, outputNames);
+ results = session.Run(modelInputsList, outputNames);
}
catch (OnnxRuntimeException e)
{
diff --git a/addons/godot_rl_agents/onnx/wrapper/ONNX_wrapper.gd b/addons/godot_rl_agents/onnx/wrapper/ONNX_wrapper.gd
index e27f2c3..7d29a03 100644
--- a/addons/godot_rl_agents/onnx/wrapper/ONNX_wrapper.gd
+++ b/addons/godot_rl_agents/onnx/wrapper/ONNX_wrapper.gd
@@ -21,9 +21,9 @@ func _init(model_path, batch_size):
action_output_size = inferencer.Initialize(model_path, batch_size)
# This function is the one that will be called from the game,
-# requires the observation as an array and the state_ins as an int
-# returns an Array containing the action the model takes.
-func run_inference(obs: Array, state_ins: int) -> Dictionary:
+# requires the observations as an Dictionary and the state_ins as an int
+# returns a Dictionary containing the action the model takes.
+func run_inference(obs: Dictionary, state_ins: int) -> Dictionary:
if inferencer == null:
printerr("Inferencer not initialized")
return {}
diff --git a/addons/godot_rl_agents/sync.gd b/addons/godot_rl_agents/sync.gd
index f47decb..20909e9 100644
--- a/addons/godot_rl_agents/sync.gd
+++ b/addons/godot_rl_agents/sync.gd
@@ -230,7 +230,9 @@ func _inference_process():
for agent_id in range(0, agents_inference.size()):
var model: ONNXModel = agents_inference[agent_id].onnx_model
- var action = model.run_inference(obs[agent_id]["obs"], 1.0)
+ var action = model.run_inference(
+ obs[agent_id], 1.0
+ )
var action_dict = _extract_action_dict(
action["output"], _action_space_inference[agent_id], model.action_means_only
)