Skip to content

Commit 7d76a4b

Browse files
committed
Create ConcatDataset.cs
1 parent dd2e1ed commit 7d76a4b

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

src/TorchSharp/ConcatDataset.cs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using TorchSharp.Modules;
6+
7+
namespace TorchSharp
8+
{
9+
public static partial class torch
10+
{
11+
public static partial class utils
12+
{
13+
public static partial class data
14+
{
15+
public static ConcatDataset<T> ConcatDataset<T>(IEnumerable<IDataset<T>> datasets)
16+
{
17+
return new ConcatDataset<T>(datasets);
18+
}
19+
}
20+
}
21+
}
22+
23+
namespace Modules
24+
{
25+
public class ConcatDataset<T> : torch.utils.data.Dataset<T>
26+
{
27+
private static IEnumerable<long> Cumsum(
28+
IEnumerable<torch.utils.data.IDataset<T>> datasets)
29+
{
30+
var s = 0L;
31+
foreach (var e in datasets) {
32+
s += e.Count;
33+
yield return s;
34+
}
35+
}
36+
private static long bisectRight(long[] a, long x)
37+
{
38+
var lo = 0;
39+
var hi = a.Length;
40+
while (lo < hi) {
41+
var mid = (lo + hi) / 2;
42+
if (x < a[mid])
43+
hi = mid;
44+
else
45+
lo = mid + 1;
46+
}
47+
return lo;
48+
}
49+
50+
51+
private readonly torch.utils.data.IDataset<T>[] _datasets;
52+
public IReadOnlyList<torch.utils.data.IDataset<T>> datasets => _datasets;
53+
54+
private readonly long[] _cumulativeSizes;
55+
public IReadOnlyList<long> cumulative_sizes => _cumulativeSizes;
56+
57+
private readonly bool autoDispose;
58+
59+
public ConcatDataset(
60+
IEnumerable<torch.utils.data.IDataset<T>> datasets,
61+
bool autoDispose = true)
62+
{
63+
this._datasets = datasets.ToArray();
64+
if (this._datasets.Length is 0)
65+
throw new ArgumentException(
66+
"datasets should not be an empty iterable", nameof(datasets));
67+
68+
// PyTorch also says 'ConcatDataset does not support IterableDataset'.
69+
// But it's not our torch.utils.data.IterableDataset in TorchSharp.
70+
this._cumulativeSizes = Cumsum(datasets).ToArray();
71+
72+
this.autoDispose = autoDispose;
73+
}
74+
75+
public override long Count => this._cumulativeSizes.Last();
76+
77+
public override T GetTensor(long index)
78+
{
79+
if (index < 0) {
80+
if (-index > this.Count) {
81+
throw new ArgumentException(
82+
"absolute value of index should not exceed dataset length",
83+
nameof(index));
84+
}
85+
index = this.Count + index;
86+
}
87+
88+
var datasetIdx = bisectRight(this._cumulativeSizes, index);
89+
long sampleIdx;
90+
if (datasetIdx == 0)
91+
sampleIdx = index;
92+
else
93+
sampleIdx = index - this._cumulativeSizes[datasetIdx - 1];
94+
return this._datasets[datasetIdx][sampleIdx];
95+
}
96+
97+
protected override void Dispose(bool disposing)
98+
{
99+
if (disposing && autoDispose) {
100+
foreach (var dataset in this._datasets)
101+
dataset.Dispose();
102+
}
103+
104+
base.Dispose(disposing);
105+
}
106+
}
107+
}
108+
}

0 commit comments

Comments
 (0)