56 changes: 31 additions & 25 deletions addons/godot_rl_agents/onnx/csharp/ONNXInference.cs
@@ -19,48 +19,54 @@ public partial class ONNXInference : GodotObject

private SessionOptions SessionOpt;

- /// <summary>
- /// init function
- /// </summary>
- /// <param name="Path"></param>
- /// <param name="BatchSize"></param>
- /// <returns>Returns the output size of the model</returns>
- public int Initialize(string Path, int BatchSize)
+ /// <summary>
+ /// init function
+ /// </summary>
+ /// <param name="Path"></param>
+ /// <param name="BatchSize"></param>
+ /// <returns>Returns the output size of the model</returns>
+ public int Initialize(string Path, int BatchSize)
{
modelPath = Path;
batchSize = BatchSize;
- SessionOpt = SessionConfigurator.MakeConfiguredSessionOptions();
- session = LoadModel(modelPath);
- return session.OutputMetadata["output"].Dimensions[1];
- }
+ SessionOpt = SessionConfigurator.MakeConfiguredSessionOptions();
+ session = LoadModel(modelPath);
+ return session.OutputMetadata["output"].Dimensions[1];
+ }


/// <include file='docs/ONNXInference.xml' path='docs/members[@name="ONNXInference"]/Run/*'/>
- public Godot.Collections.Dictionary<string, Godot.Collections.Array<float>> RunInference(Godot.Collections.Array<float> obs, int state_ins)
+ public Godot.Collections.Dictionary<string, Godot.Collections.Array<float>> RunInference(Godot.Collections.Dictionary<string, Godot.Collections.Array<float>> obs, int state_ins)
{
//Current model: Any (Godot Rl Agents)
- //Expects a tensor of shape [batch_size, input_size] type float named obs and a tensor of shape [batch_size] type float named state_ins
+ //Expects a tensor of shape [batch_size, input_size] type float for any output of the agents observation dictionary and a tensor of shape [batch_size] type float named state_ins

- //Fill the input tensors
- // create span from inputSize
- var span = new float[obs.Count]; //There's probably a better way to do this
- for (int i = 0; i < obs.Count; i++)
+ var modelInputsList = new List<NamedOnnxValue>
{
- span[i] = obs[i];
+ NamedOnnxValue.CreateFromTensor("state_ins", new DenseTensor<float>(new float[] { state_ins }, new int[] { batchSize }))
+ };
+ foreach (var key in obs.Keys)
+ {
+ var subObs = obs[key];
+ // Fill the input tensors for each key of the observation
+ // create span of observation from specific inputSize
+ var obsData = new float[subObs.Count]; //There's probably a better way to do this
+ for (int i = 0; i < subObs.Count; i++)
+ {
+ obsData[i] = subObs[i];
+ }
+ modelInputsList.Add(
+ NamedOnnxValue.CreateFromTensor(key, new DenseTensor<float>(obsData, new int[] { batchSize, subObs.Count }))
+ );
}

- IReadOnlyCollection<NamedOnnxValue> inputs = new List<NamedOnnxValue>
- {
- NamedOnnxValue.CreateFromTensor("obs", new DenseTensor<float>(span, new int[] { batchSize, obs.Count })),
- NamedOnnxValue.CreateFromTensor("state_ins", new DenseTensor<float>(new float[] { state_ins }, new int[] { batchSize }))
- };
IReadOnlyCollection<string> outputNames = new List<string> { "output", "state_outs" }; //ONNX is sensible to these names, as well as the input names

- IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results;
+ IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results;
//We do not use "using" here so we get a better exception explaination later
try
{
- results = session.Run(inputs, outputNames);
+ results = session.Run(modelInputsList, outputNames);
}
catch (OnnxRuntimeException e)
{
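In practice, RunInference now builds one input tensor per key of the observation dictionary (plus the fixed "state_ins" tensor) instead of a single tensor named "obs", so each dictionary key must match an input name of the exported ONNX model and map to a flat float array. A minimal caller-side sketch in GDScript, assuming a single input still named "obs" and an already-initialized inferencer object (names and values are illustrative, not part of this diff):

var observation := {
	"obs": [0.1, 0.2, 0.3, 0.4]  # becomes a float tensor of shape [batch_size, 4] on the C# side
}
# state_ins stays a single value expanded to shape [batch_size]
var result = inferencer.RunInference(observation, 0)
print(result["output"])  # the model's "output" head, returned as a flat float array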
6 changes: 3 additions & 3 deletions addons/godot_rl_agents/onnx/wrapper/ONNX_wrapper.gd
@@ -21,9 +21,9 @@ func _init(model_path, batch_size):
action_output_size = inferencer.Initialize(model_path, batch_size)

# This function is the one that will be called from the game,
- # requires the observation as an array and the state_ins as an int
- # returns an Array containing the action the model takes.
- func run_inference(obs: Array, state_ins: int) -> Dictionary:
+ # requires the observations as an Dictionary and the state_ins as an int
+ # returns a Dictionary containing the action the model takes.
+ func run_inference(obs: Dictionary, state_ins: int) -> Dictionary:
if inferencer == null:
printerr("Inferencer not initialized")
return {}
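A short usage sketch of the updated wrapper, assuming the ONNXModel class name used in sync.gd, a hypothetical model path, and a batch size of 1 (illustrative only):

var model = ONNXModel.new("res://model.onnx", 1)  # hypothetical path
# obs is now a Dictionary keyed by the model's input names, not a flat Array
var action = model.run_inference({"obs": [0.0, 0.5, 1.0]}, 1)
var output = action["output"]  # flat float array with the model's action output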
4 changes: 3 additions & 1 deletion addons/godot_rl_agents/sync.gd
@@ -230,7 +230,9 @@ func _inference_process():

for agent_id in range(0, agents_inference.size()):
var model: ONNXModel = agents_inference[agent_id].onnx_model
- var action = model.run_inference(obs[agent_id]["obs"], 1.0)
+ var action = model.run_inference(
+ obs[agent_id], 1.0
+ )
var action_dict = _extract_action_dict(
action["output"], _action_space_inference[agent_id], model.action_means_only
)
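With this change the full per-agent observation dictionary reaches the model, so agents exported with more than one observation input keep all of their keys. A rough sketch of the forwarded data, where every key except "obs" is hypothetical:

# Previously only obs[agent_id]["obs"] was passed to run_inference; now the
# whole dictionary goes through, so a multi-input observation survives intact:
var example_agent_obs = {
	"obs": [0.1, -0.3, 0.7],      # the flat observation vector
	"extra_sensor": [0.0, 1.0]    # hypothetical second input of the exported model
}
var action = model.run_inference(example_agent_obs, 1)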