Skip to content

Commit 3e6ba78

Browse files
committed
Multiple LSTM cell handling added to Barracuda code path
1 parent a0b2743 commit 3e6ba78

File tree

7 files changed

+71
-41
lines changed

7 files changed

+71
-41
lines changed

UnitySDK/Assets/ML-Agents/Scripts/Agent.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,11 @@ public void AppendMemoriesAction(List<float> memories)
918918
action.memories.AddRange(memories);
919919
}
920920

921+
public List<float> GetMemoriesAction()
922+
{
923+
return action.memories;
924+
}
925+
921926
/// <summary>
922927
/// Updates the text action.
923928
/// </summary>

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System.Collections.Generic;
2+
using System.Linq;
23
using MLAgents.InferenceBrain.Utils;
34
using UnityEngine;
45

@@ -99,34 +100,38 @@ public void Apply(Tensor tensor, Dictionary<Agent, AgentInfo> agentInfo)
99100

100101
public class BarracudaMemoryOutputApplier : TensorApplier.Applier
101102
{
102-
private bool firstHalf = true;
103+
private int memoriesCount;
104+
private int memoryIndex;
103105

104-
public BarracudaMemoryOutputApplier(bool firstHalf)
106+
public BarracudaMemoryOutputApplier(int memoriesCount, int memoryIndex)
105107
{
106-
this.firstHalf = firstHalf;
108+
this.memoriesCount = memoriesCount;
109+
this.memoryIndex = memoryIndex;
107110
}
108111

109112
public void Apply(Tensor tensor, Dictionary<Agent, AgentInfo> agentInfo)
110113
{
111114
var tensorDataMemory = tensor.Data as float[,];
112115
var agentIndex = 0;
113-
var memorySize = tensor.Shape[tensor.Shape.Length - 1];
116+
var memorySize = (int)tensor.Shape[tensor.Shape.Length - 1];
117+
114118
foreach (var agent in agentInfo.Keys)
115119
{
116-
var memory = new List<float>();
117-
for (var j = 0; j < memorySize; j++)
118-
{
119-
memory.Add(tensorDataMemory[agentIndex, j]);
120-
}
120+
var memory = agent.GetMemoriesAction();
121121

122-
if (firstHalf)
122+
if (memory == null || memory.Count < memorySize * memoriesCount)
123123
{
124-
agent.UpdateMemoriesAction(memory);
124+
memory = new List<float>();
125+
memory.AddRange(Enumerable.Repeat(0f, memorySize * memoriesCount));
125126
}
126-
else
127+
128+
for (var j = 0; j < memorySize; j++)
127129
{
128-
agent.AppendMemoriesAction(memory);
130+
memory[memorySize * memoryIndex + j] = tensorDataMemory[agentIndex, j];
129131
}
132+
133+
agent.UpdateMemoriesAction(memory);
134+
130135
agentIndex++;
131136
}
132137
}

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/BarracudaModelParamLoader.cs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,10 @@ public string[] GetOutputNames()
109109
var memory = GetIntScalar(TensorNames.MemorySize);
110110
if (memory > 0)
111111
{
112-
names.Add(TensorNames.RecurrentOutput_C);
113-
names.Add(TensorNames.RecurrentOutput_H);
112+
foreach (var mem in _model.memories)
113+
{
114+
names.Add(mem.output);
115+
}
114116
}
115117

116118
names.Sort();
@@ -264,8 +266,8 @@ private void CheckInputTensorPresence(int memory, ModelActionType isContinuous)
264266
// If the model has a non-negative memory size but requires a recurrent input
265267
if (memory > 0)
266268
{
267-
if (!tensorsNames.Contains(TensorNames.RecurrentInPlaceholder_H) ||
268-
!tensorsNames.Contains(TensorNames.RecurrentInPlaceholder_C))
269+
if (!tensorsNames.Any(x => x.EndsWith("_h")) ||
270+
!tensorsNames.Any(x => x.EndsWith("_c")))
269271
{
270272
_failedModelChecks.Add(
271273
"The model does not contain a Recurrent Input Node but has memory_size.");
@@ -302,8 +304,8 @@ private void CheckOutputTensorPresence(int memory)
302304
{
303305
var memOutputs = _model.memories.Select(x => x.output).ToList();
304306

305-
if (!memOutputs.Contains(TensorNames.RecurrentOutput_H) ||
306-
!memOutputs.Contains(TensorNames.RecurrentOutput_C))
307+
if (!memOutputs.Any(x => x.EndsWith("_h")) ||
308+
!memOutputs.Any(x => x.EndsWith("_c")))
307309
{
308310
_failedModelChecks.Add(
309311
"The model does not contain a Recurrent Output Node but has memory_size.");
@@ -325,9 +327,12 @@ private void CheckInputTensorShape()
325327
{TensorNames.RandomNormalEpsilonPlaceholder, ((tensor) => null)},
326328
{TensorNames.ActionMaskPlaceholder, ((tensor) => null)},
327329
{TensorNames.SequenceLengthPlaceholder, ((tensor) => null)},
328-
{TensorNames.RecurrentInPlaceholder_H, ((tensor) => null)},
329-
{TensorNames.RecurrentInPlaceholder_C, ((tensor) => null)},
330+
{TensorNames.RecurrentInPlaceholder, ((tensor) => null)},
330331
};
332+
333+
foreach (var mem in _model.memories)
334+
tensorTester[mem.input] = ((tensor) => null);
335+
331336
for (var obsIndex = 0; obsIndex < _brainParameters.cameraResolutions.Length; obsIndex++)
332337
{
333338
var index = obsIndex;

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System.Collections.Generic;
22
using System;
33
using System.Linq;
4+
using Barracuda;
45
using MLAgents.InferenceBrain.Utils;
56

67
namespace MLAgents.InferenceBrain
@@ -119,28 +120,28 @@ public void Generate(Tensor tensor, int batchSize, Dictionary<Agent, AgentInfo>
119120

120121
public class BarracudaRecurrentInputGenerator : TensorGenerator.Generator
121122
{
122-
private bool firstHalf = true;
123+
private int memoriesCount;
124+
private int memoryIndex;
123125

124-
public BarracudaRecurrentInputGenerator(bool firstHalf)
126+
public BarracudaRecurrentInputGenerator(int memoriesCount, int memoryIndex)
125127
{
126-
this.firstHalf = firstHalf;
128+
this.memoriesCount = memoriesCount;
129+
this.memoryIndex = memoryIndex;
127130
}
128131

129132
public void Generate(Tensor tensor, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
130133
{
131134
tensor.Shape[0] = batchSize;
132-
var memorySize = tensor.Shape[tensor.Shape.Length - 1];
135+
136+
var memorySize = (int)tensor.Shape[tensor.Shape.Length - 1];
137+
133138
tensor.Data = new float[batchSize, memorySize];
134139
var agentIndex = 0;
135140
foreach (var agent in agentInfo.Keys)
136141
{
137-
var memory = agentInfo[agent].memories;
142+
var memory = agentInfo[agent].memories;
138143

139-
int offset = 0;
140-
if (!firstHalf)
141-
{
142-
offset = memory.Count - (int)memorySize;
143-
}
144+
int offset = memorySize * memoryIndex;
144145

145146
if (memory == null)
146147
{

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Collections.Generic;
1+
#define ENABLE_BARRACUDA
2+
using System.Collections.Generic;
23

34
namespace MLAgents.InferenceBrain
45
{
@@ -37,7 +38,7 @@ public interface Applier
3738
/// <param name="bp"> The BrainParameters used to determine what Appliers will be
3839
/// used</param>
3940
/// <param name="seed"> The seed the Appliers will be initialized with.</param>
40-
public TensorApplier(BrainParameters bp, int seed)
41+
public TensorApplier(BrainParameters bp, int seed, object barracudaModel = null)
4142
{
4243
_dict[TensorNames.ValueEstimateOutput] = new ValueEstimateApplier();
4344
if (bp.vectorActionSpaceType == SpaceType.continuous)
@@ -51,8 +52,14 @@ public TensorApplier(BrainParameters bp, int seed)
5152
}
5253
_dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier();
5354

54-
_dict[TensorNames.RecurrentOutput_C] = new BarracudaMemoryOutputApplier(true);
55-
_dict[TensorNames.RecurrentOutput_H] = new BarracudaMemoryOutputApplier(false);
55+
#if ENABLE_BARRACUDA
56+
Barracuda.Model model = (Barracuda.Model) barracudaModel;
57+
58+
for (var i = 0; i < model?.memories.Length; i++)
59+
{
60+
_dict[model.memories[i].output] = new BarracudaMemoryOutputApplier(model.memories.Length, i);
61+
}
62+
#endif
5663
}
5764

5865
/// <summary>

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
using System.Collections.Generic;
1+
#define ENABLE_BARRACUDA
2+
using System.Collections.Generic;
3+
using Barracuda;
24

35
namespace MLAgents.InferenceBrain
46
{
@@ -37,16 +39,21 @@ public interface Generator
3739
/// <param name="bp"> The BrainParameters used to determine what Generators will be
3840
/// used</param>
3941
/// <param name="seed"> The seed the Generators will be initialized with.</param>
40-
public TensorGenerator(BrainParameters bp, int seed)
42+
public TensorGenerator(BrainParameters bp, int seed, object barracudaModel = null)
4143
{
4244
// Generator for Inputs
4345
_dict[TensorNames.BatchSizePlaceholder] = new BatchSizeGenerator();
4446
_dict[TensorNames.SequenceLengthPlaceholder] = new SequenceLengthGenerator();
4547
_dict[TensorNames.VectorObservationPlacholder] = new VectorObservationGenerator();
4648
_dict[TensorNames.RecurrentInPlaceholder] = new RecurrentInputGenerator();
4749

48-
_dict[TensorNames.RecurrentInPlaceholder_C] = new BarracudaRecurrentInputGenerator(true);
49-
_dict[TensorNames.RecurrentInPlaceholder_H] = new BarracudaRecurrentInputGenerator(false);
50+
#if ENABLE_BARRACUDA
51+
Barracuda.Model model = (Barracuda.Model) barracudaModel;
52+
for (var i = 0; i < model?.memories.Length; i++)
53+
{
54+
_dict[model.memories[i].input] = new BarracudaRecurrentInputGenerator(model.memories.Length, i);
55+
}
56+
#endif
5057

5158
_dict[TensorNames.PreviousActionPlaceholder] = new PreviousActionInputGenerator();
5259
_dict[TensorNames.ActionMaskPlaceholder] = new ActionMaskInputGenerator();

UnitySDK/Assets/ML-Agents/Scripts/LearningBrain.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ public void ReloadModel(int seed = 0)
126126
_modelParamLoader = BarracudaModelParamLoader.GetLoaderAndCheck(_engine, _barracudaModel, brainParameters);
127127
_inferenceInputs = _modelParamLoader.GetInputTensors();
128128
_outputNames = _modelParamLoader.GetOutputNames();
129-
_tensorGenerator = new TensorGenerator(brainParameters, seed);
130-
_tensorApplier = new TensorApplier(brainParameters, seed);
129+
_tensorGenerator = new TensorGenerator(brainParameters, seed, _barracudaModel);
130+
_tensorApplier = new TensorApplier(brainParameters, seed, _barracudaModel);
131131
#endif
132132
}
133133

0 commit comments

Comments
 (0)