Skip to content

Commit b78c1e0

Browse files
mantaspErvin T
authored andcommitted
Fix tests for Barracuda (#2333)
* Removed obsolete 'TestDstWrongShape' test as it does not reflect how Barracuda tensors work * Added proper test cleanup, to avoid warning messages from finalizer thread.
1 parent 5837c71 commit b78c1e0

File tree

3 files changed

+27
-33
lines changed

3 files changed

+27
-33
lines changed

UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ private Dictionary<Agent, AgentInfo> GetFakeAgentInfos()
3636
public void Contruction()
3737
{
3838
var bp = new BrainParameters();
39-
var tensorGenerator = new TensorApplier(bp, 0, new TensorCachingAllocator());
39+
var alloc = new TensorCachingAllocator();
40+
var tensorGenerator = new TensorApplier(bp, 0, alloc);
4041
Assert.IsNotNull(tensorGenerator);
42+
alloc.Dispose();
4143
}
4244

4345
[Test]
@@ -76,8 +78,8 @@ public void ApplyDiscreteActionOutput()
7678
4f, 5f, 6f, 7f, 8f})
7779
};
7880
var agentInfos = GetFakeAgentInfos();
79-
80-
var applier = new DiscreteActionOutputApplier(new int[]{2, 3}, 0, new TensorCachingAllocator());
81+
var alloc = new TensorCachingAllocator();
82+
var applier = new DiscreteActionOutputApplier(new int[]{2, 3}, 0, alloc);
8183
applier.Apply(inputTensor, agentInfos);
8284
var agents = agentInfos.Keys.ToList();
8385
var agent = agents[0] as TestAgent;
@@ -88,6 +90,7 @@ public void ApplyDiscreteActionOutput()
8890
action = agent.GetAction();
8991
Assert.AreEqual(action.vectorActions[0], 1);
9092
Assert.AreEqual(action.vectorActions[1], 2);
93+
alloc.Dispose();
9194
}
9295

9396
[Test]

UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,30 +44,36 @@ private Dictionary<Agent, AgentInfo> GetFakeAgentInfos()
4444
public void Contruction()
4545
{
4646
var bp = new BrainParameters();
47-
var tensorGenerator = new TensorGenerator(bp, 0, new TensorCachingAllocator());
47+
var alloc = new TensorCachingAllocator();
48+
var tensorGenerator = new TensorGenerator(bp, 0, alloc);
4849
Assert.IsNotNull(tensorGenerator);
50+
alloc.Dispose();
4951
}
5052

5153
[Test]
5254
public void GenerateBatchSize()
5355
{
5456
var inputTensor = new TensorProxy();
57+
var alloc = new TensorCachingAllocator();
5558
var batchSize = 4;
56-
var generator = new BatchSizeGenerator(new TensorCachingAllocator());
59+
var generator = new BatchSizeGenerator(alloc);
5760
generator.Generate(inputTensor, batchSize, null);
5861
Assert.IsNotNull(inputTensor.Data);
5962
Assert.AreEqual(inputTensor.Data[0], batchSize);
63+
alloc.Dispose();
6064
}
6165

6266
[Test]
6367
public void GenerateSequenceLength()
6468
{
6569
var inputTensor = new TensorProxy();
70+
var alloc = new TensorCachingAllocator();
6671
var batchSize = 4;
67-
var generator = new SequenceLengthGenerator(new TensorCachingAllocator());
72+
var generator = new SequenceLengthGenerator(alloc);
6873
generator.Generate(inputTensor, batchSize, null);
6974
Assert.IsNotNull(inputTensor.Data);
7075
Assert.AreEqual(inputTensor.Data[0], 1);
76+
alloc.Dispose();
7177
}
7278

7379
[Test]
@@ -79,14 +85,15 @@ public void GenerateVectorObservation()
7985
};
8086
var batchSize = 4;
8187
var agentInfos = GetFakeAgentInfos();
82-
83-
var generator = new VectorObservationGenerator(new TensorCachingAllocator());
88+
var alloc = new TensorCachingAllocator();
89+
var generator = new VectorObservationGenerator(alloc);
8490
generator.Generate(inputTensor, batchSize, agentInfos);
8591
Assert.IsNotNull(inputTensor.Data);
8692
Assert.AreEqual(inputTensor.Data[0, 0], 1);
8793
Assert.AreEqual(inputTensor.Data[0, 2], 3);
8894
Assert.AreEqual(inputTensor.Data[1, 0], 4);
8995
Assert.AreEqual(inputTensor.Data[1, 2], 6);
96+
alloc.Dispose();
9097
}
9198

9299
[Test]
@@ -98,14 +105,15 @@ public void GenerateRecurrentInput()
98105
};
99106
var batchSize = 4;
100107
var agentInfos = GetFakeAgentInfos();
101-
102-
var generator = new RecurrentInputGenerator(new TensorCachingAllocator());
108+
var alloc = new TensorCachingAllocator();
109+
var generator = new RecurrentInputGenerator(alloc);
103110
generator.Generate(inputTensor, batchSize, agentInfos);
104111
Assert.IsNotNull(inputTensor.Data);
105112
Assert.AreEqual(inputTensor.Data[0, 0], 0);
106113
Assert.AreEqual(inputTensor.Data[0, 4], 0);
107114
Assert.AreEqual(inputTensor.Data[1, 0], 1);
108115
Assert.AreEqual(inputTensor.Data[1, 4], 0);
116+
alloc.Dispose();
109117
}
110118

111119
[Test]
@@ -119,15 +127,16 @@ public void GeneratePreviousActionInput()
119127
};
120128
var batchSize = 4;
121129
var agentInfos = GetFakeAgentInfos();
122-
123-
var generator = new PreviousActionInputGenerator(new TensorCachingAllocator());
130+
var alloc = new TensorCachingAllocator();
131+
var generator = new PreviousActionInputGenerator(alloc);
124132

125133
generator.Generate(inputTensor, batchSize, agentInfos);
126134
Assert.IsNotNull(inputTensor.Data);
127135
Assert.AreEqual(inputTensor.Data[0, 0], 1);
128136
Assert.AreEqual(inputTensor.Data[0, 1], 2);
129137
Assert.AreEqual(inputTensor.Data[1, 0], 3);
130138
Assert.AreEqual(inputTensor.Data[1, 1], 4);
139+
alloc.Dispose();
131140
}
132141

133142
[Test]
@@ -141,14 +150,15 @@ public void GenerateActionMaskInput()
141150
};
142151
var batchSize = 4;
143152
var agentInfos = GetFakeAgentInfos();
144-
145-
var generator = new ActionMaskInputGenerator(new TensorCachingAllocator());
153+
var alloc = new TensorCachingAllocator();
154+
var generator = new ActionMaskInputGenerator(alloc);
146155
generator.Generate(inputTensor, batchSize, agentInfos);
147156
Assert.IsNotNull(inputTensor.Data);
148157
Assert.AreEqual(inputTensor.Data[0, 0], 1);
149158
Assert.AreEqual(inputTensor.Data[0, 4], 1);
150159
Assert.AreEqual(inputTensor.Data[1, 0], 0);
151160
Assert.AreEqual(inputTensor.Data[1, 4], 1);
161+
alloc.Dispose();
152162
}
153163
}
154164
}

UnitySDK/Assets/ML-Agents/Editor/Tests/MultinomialTest.cs

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -160,25 +160,6 @@ public void TestDstDataNull()
160160
Assert.Throws<ArgumentNullException>(() => m.Eval(src, dst));
161161
}
162162

163-
[Test]
164-
public void TestDstWrongShape()
165-
{
166-
Multinomial m = new Multinomial(2018);
167-
168-
TensorProxy src = new TensorProxy
169-
{
170-
ValueType = TensorProxy.TensorType.FloatingPoint,
171-
Data = new Tensor(0,1)
172-
};
173-
TensorProxy dst = new TensorProxy
174-
{
175-
ValueType = TensorProxy.TensorType.FloatingPoint,
176-
Data = new Tensor(0,2)
177-
};
178-
179-
Assert.Throws<ArgumentException>(() => m.Eval(src, dst));
180-
}
181-
182163
[Test]
183164
public void TestUnequalBatchSize()
184165
{

0 commit comments

Comments
 (0)