Skip to content

Commit 3dfec6b

Browse files
committed
add tests
1 parent 7d76a4b commit 3dfec6b

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

test/TorchSharpTest/TestDataLoader.cs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
22

33
using System.Collections.Generic;
4+
using System.Linq;
5+
using TorchSharp.Modules;
46
using Xunit;
57

68

@@ -27,6 +29,26 @@ private class TestIterableDataset : torch.utils.data.IterableDataset
2729
}
2830
}
2931

32+
private class TestDatasetFromIEnumerable<T> : torch.utils.data.IDataset<T>
33+
{
34+
private readonly T[] values;
35+
public TestDatasetFromIEnumerable(IEnumerable<T> values)
36+
{
37+
this.values = values.ToArray();
38+
this.Disposed = false;
39+
}
40+
41+
public bool Disposed { get; set; }
42+
43+
public T this[long index] => values[index];
44+
45+
public long Count => values.LongLength;
46+
47+
public void Dispose() {
48+
this.Disposed = true;
49+
}
50+
}
51+
3052
[Fact]
3153
public void DatasetTest()
3254
{
@@ -230,5 +252,59 @@ public void CustomSeedTest()
230252
iterator.Dispose();
231253
iterator2.Dispose();
232254
}
255+
256+
[Fact]
257+
public void ConcatDatasetTest()
258+
{
259+
using var dataset1 = new TestDatasetFromIEnumerable<(int, int)>(new[] {
260+
(1, 1), // dataset 1 value 1
261+
(1, 2), // dataset 1 value 2
262+
(1, 3),
263+
});
264+
using var dataset2 = new TestDatasetFromIEnumerable<(int, int)>(new[] {
265+
(2, 1),
266+
(2, 2),
267+
});
268+
using var dataset3 = new TestDatasetFromIEnumerable<(int, int)>(new[] {
269+
(3, 1),
270+
(3, 2),
271+
(3, 3),
272+
(3, 4),
273+
});
274+
275+
using var dataset = new ConcatDataset<(int, int)>(new[] {
276+
dataset1, dataset2, dataset3
277+
});
278+
279+
Assert.Equal(3 + 2 + 4, dataset.Count);
280+
281+
Assert.Equal((1, 1), dataset[0]);
282+
Assert.Equal((1, 2), dataset[1]);
283+
Assert.Equal((1, 3), dataset[2]);
284+
Assert.Equal((2, 1), dataset[3]);
285+
Assert.Equal((2, 2), dataset[4]);
286+
Assert.Equal((3, 1), dataset[5]);
287+
Assert.Equal((3, 2), dataset[6]);
288+
Assert.Equal((3, 3), dataset[7]);
289+
Assert.Equal((3, 4), dataset[8]);
290+
291+
Assert.Equal((1, 1), dataset[-9]);
292+
Assert.Equal((1, 2), dataset[-8]);
293+
Assert.Equal((1, 3), dataset[-7]);
294+
Assert.Equal((2, 1), dataset[-6]);
295+
Assert.Equal((2, 2), dataset[-5]);
296+
Assert.Equal((3, 1), dataset[-4]);
297+
Assert.Equal((3, 2), dataset[-3]);
298+
Assert.Equal((3, 3), dataset[-2]);
299+
Assert.Equal((3, 4), dataset[-1]);
300+
301+
Assert.False(dataset1.Disposed);
302+
Assert.False(dataset2.Disposed);
303+
Assert.False(dataset3.Disposed);
304+
dataset.Dispose();
305+
Assert.True(dataset1.Disposed);
306+
Assert.True(dataset2.Disposed);
307+
Assert.True(dataset3.Disposed);
308+
}
233309
}
234310
}

0 commit comments

Comments
 (0)