Skip to content

Commit 35f4c7d

Browse files
Merge pull request #1354 from yueyinqiu/ConcatDataset
concat dataset
2 parents c332e9d + 94635e9 commit 35f4c7d

File tree

6 files changed

+281
-36
lines changed

6 files changed

+281
-36
lines changed

RELEASENOTES.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,18 @@
22

33
Releases, starting with 9/2/2021, are listed with the most recent release at the top.
44

5+
# NuGet Version 0.102.7
6+
7+
__Breaking Changes__:
8+
9+
A new interface `IDataset<out T>` has been added. (Now `Dataset<T>` implements `IDataset<T>`; `Dataset` implements both `IDataset<Dictionary<string, Tensor>>` and `IDataset<IReadOnlyDictionary<string, Tensor>>`; `IterableDataset` implements `IDataset<IList<Tensor>>` and `IDataset<IEnumerable<Tensor>>`.)<br/>
10+
`torch.utils.data.ConcatDataset` has been added.<br/>
11+
12+
__API Changes__:
13+
14+
The parameter of `DataLoader`s has been relaxed to `IDataset`.<br/>
15+
The parameter of `DataLoader`s' collate functions has been relaxed to `IReadOnlyList`.<br/>
16+
517
# NuGet Version 0.102.6
618

719
__Breaking Changes__:

src/TorchSharp/ConcatDataset.cs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using TorchSharp.Modules;
6+
7+
namespace TorchSharp
8+
{
9+
public static partial class torch
10+
{
11+
public static partial class utils
12+
{
13+
public static partial class data
14+
{
15+
/// <summary>
/// Creates a dataset that presents the given datasets as one concatenated dataset.
/// </summary>
/// <typeparam name="T">The item type produced by every dataset.</typeparam>
/// <param name="datasets">The datasets to concatenate, in order.</param>
/// <returns>
/// A new <c>ConcatDataset</c> wrapping <paramref name="datasets"/>, constructed with
/// its default <c>autoDispose</c> setting.
/// </returns>
public static ConcatDataset<T> ConcatDataset<T>(IEnumerable<IDataset<T>> datasets)
    => new ConcatDataset<T>(datasets);
19+
}
20+
}
21+
}
22+
23+
namespace Modules
24+
{
25+
/// <summary>
/// A dataset that exposes several datasets as a single one, laid out end to end:
/// indices of the first dataset come first, then those of the second, and so on.
/// </summary>
/// <typeparam name="T">The item type produced by every concatenated dataset.</typeparam>
public class ConcatDataset<T> : torch.utils.data.Dataset<T>
{
    /// <summary>
    /// Lazily yields the running total of <c>Count</c> over <paramref name="datasets"/>.
    /// </summary>
    private static IEnumerable<long> Cumsum(
        IEnumerable<torch.utils.data.IDataset<T>> datasets)
    {
        var s = 0L;
        foreach (var e in datasets) {
            s += e.Count;
            yield return s;
        }
    }

    /// <summary>
    /// Binary search over the ascending array <paramref name="a"/>: returns the first
    /// position whose element is strictly greater than <paramref name="x"/>, or
    /// <c>a.Length</c> when no such element exists.
    /// </summary>
    private static long bisectRight(long[] a, long x)
    {
        var lo = 0;
        var hi = a.Length;
        while (lo < hi) {
            // Written as lo + (hi - lo) / 2 so the midpoint cannot overflow int.
            var mid = lo + (hi - lo) / 2;
            if (x < a[mid])
                hi = mid;
            else
                lo = mid + 1;
        }
        return lo;
    }

    private readonly torch.utils.data.IDataset<T>[] _datasets;

    /// <summary>
    /// The concatenated datasets, in concatenation order.
    /// </summary>
    public IReadOnlyList<torch.utils.data.IDataset<T>> datasets => _datasets;

    private readonly long[] _cumulativeSizes;

    /// <summary>
    /// Running totals of the dataset sizes; element <c>i</c> is the combined
    /// <c>Count</c> of datasets <c>0..i</c>.
    /// </summary>
    public IReadOnlyList<long> cumulative_sizes => _cumulativeSizes;

    private readonly bool autoDispose;

    /// <summary>
    /// Creates a concatenation of <paramref name="datasets"/>.
    /// </summary>
    /// <param name="datasets">The datasets to concatenate. Enumerated exactly once.</param>
    /// <param name="autoDispose">
    /// Whether disposing this instance also disposes every concatenated dataset.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="datasets"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="datasets"/> is empty.</exception>
    public ConcatDataset(
        IEnumerable<torch.utils.data.IDataset<T>> datasets,
        bool autoDispose = true)
    {
        if (datasets is null)
            throw new ArgumentNullException(nameof(datasets));

        this._datasets = datasets.ToArray();
        if (this._datasets.Length is 0)
            throw new ArgumentException(
                "datasets should not be an empty iterable", nameof(datasets));

        // PyTorch also says 'ConcatDataset does not support IterableDataset'.
        // But it's not our torch.utils.data.IterableDataset in TorchSharp.

        // Sizes are taken from the materialized array, not from the caller's
        // enumerable: a lazy or single-use sequence could otherwise be enumerated
        // twice and yield sizes that do not belong to the stored datasets.
        this._cumulativeSizes = Cumsum(this._datasets).ToArray();

        this.autoDispose = autoDispose;
    }

    /// <summary>
    /// Total number of items across all concatenated datasets.
    /// </summary>
    public override long Count => this._cumulativeSizes[this._cumulativeSizes.Length - 1];

    /// <summary>
    /// Gets the item at <paramref name="index"/> of the concatenated sequence.
    /// </summary>
    /// <param name="index">
    /// Position in the concatenated dataset. A negative value counts from the end,
    /// so <c>-1</c> is the last item.
    /// </param>
    /// <returns>The item fetched from the dataset that owns the position.</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="index"/> is negative and its absolute value exceeds <see cref="Count"/>.
    /// </exception>
    /// <exception cref="IndexOutOfRangeException">
    /// <paramref name="index"/> is greater than or equal to <see cref="Count"/>.
    /// </exception>
    public override T GetTensor(long index)
    {
        if (index < 0) {
            if (-index > this.Count) {
                throw new ArgumentException(
                    "absolute value of index should not exceed dataset length",
                    nameof(index));
            }
            index = this.Count + index;
        }

        // Find the dataset whose cumulative range contains the index, then
        // translate the global index into an index local to that dataset.
        var datasetIdx = bisectRight(this._cumulativeSizes, index);
        long sampleIdx;
        if (datasetIdx == 0)
            sampleIdx = index;
        else
            sampleIdx = index - this._cumulativeSizes[datasetIdx - 1];
        return this._datasets[datasetIdx][sampleIdx];
    }

    /// <summary>
    /// Disposes the concatenated datasets when <paramref name="disposing"/> is true and
    /// auto-dispose was requested at construction, then runs the base dispose logic.
    /// </summary>
    protected override void Dispose(bool disposing)
    {
        if (disposing && autoDispose) {
            foreach (var dataset in this._datasets)
                dataset.Dispose();
        }

        base.Dispose(disposing);
    }
}
107+
}
108+
}

src/TorchSharp/DataLoader.cs

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@ public static partial class utils
1717
{
1818
public static partial class data
1919
{
20-
2120
public static Modules.DataLoader DataLoader(
22-
Dataset dataset,
21+
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
2322
int batchSize, IEnumerable<long> shuffler,
2423
Device device = null,
2524
int num_worker = 1, bool drop_last = false,
@@ -34,7 +33,7 @@ public static Modules.DataLoader DataLoader(
3433
}
3534

3635
public static Modules.DataLoader DataLoader(
37-
Dataset dataset,
36+
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
3837
int batchSize, bool shuffle = false,
3938
Device device = null, int? seed = null,
4039
int num_worker = 1, bool drop_last = false,
@@ -49,7 +48,7 @@ public static Modules.DataLoader DataLoader(
4948
}
5049

5150
public static Modules.IterableDataLoader DataLoader(
52-
IterableDataset dataset,
51+
IDataset<IEnumerable<Tensor>> dataset,
5352
int batchSize, IEnumerable<long> shuffler,
5453
Device device = null,
5554
int num_worker = 1, bool drop_last = false,
@@ -64,7 +63,7 @@ public static Modules.IterableDataLoader DataLoader(
6463
}
6564

6665
public static Modules.IterableDataLoader DataLoader(
67-
IterableDataset dataset,
66+
IDataset<IEnumerable<Tensor>> dataset,
6867
int batchSize, bool shuffle = false,
6968
Device device = null, int? seed = null,
7069
int num_worker = 1, bool drop_last = false,
@@ -90,7 +89,8 @@ namespace Modules
9089
/// Data loader. Combines a dataset and a sampler, and provides an enumerator over the given dataset.
9190
/// </summary>
9291
/// <remarks>This class is used for map-style data sets</remarks>
93-
public class DataLoader : DataLoader<Dictionary<string, torch.Tensor>, Dictionary<string, torch.Tensor>>
92+
public class DataLoader : DataLoader<IReadOnlyDictionary<string, torch.Tensor>,
93+
Dictionary<string, torch.Tensor>>
9494
{
9595
/// <summary>
9696
/// Pytorch style dataloader
@@ -111,7 +111,7 @@ public class DataLoader : DataLoader<Dictionary<string, torch.Tensor>, Dictionar
111111
/// Indicates whether to dispose the dataset when being disposed.
112112
/// </param>
113113
public DataLoader(
114-
Dataset dataset,
114+
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
115115
int batchSize, IEnumerable<long> shuffler,
116116
Device device = null,
117117
int num_worker = 1, bool drop_last = false,
@@ -144,7 +144,7 @@ public DataLoader(
144144
/// Indicates whether to dispose the dataset when being disposed.
145145
/// </param>
146146
public DataLoader(
147-
Dataset dataset,
147+
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
148148
int batchSize, bool shuffle = false,
149149
Device device = null, int? seed = null,
150150
int num_worker = 1, bool drop_last = false,
@@ -157,7 +157,8 @@ public DataLoader(
157157
{
158158
}
159159

160-
private static Dictionary<string, torch.Tensor> Collate(IEnumerable<Dictionary<string, torch.Tensor>> dic, torch.Device device)
160+
private static Dictionary<string, torch.Tensor> Collate(
161+
IEnumerable<IReadOnlyDictionary<string, torch.Tensor>> dic, torch.Device device)
161162
{
162163
using (torch.NewDisposeScope()) {
163164
Dictionary<string, torch.Tensor> batch = new();
@@ -176,7 +177,8 @@ public DataLoader(
176177
/// Data loader. Combines a dataset and a sampler, and provides an enumerator over the given dataset.
177178
/// </summary>
178179
/// <remarks>This class is used for list-style data sets</remarks>
179-
public class IterableDataLoader : DataLoader<IList<torch.Tensor>, IList<torch.Tensor>>
180+
public class IterableDataLoader :
181+
DataLoader<IEnumerable<torch.Tensor>, IList<torch.Tensor>>
180182
{
181183
/// <summary>
182184
/// Pytorch style dataloader
@@ -197,7 +199,7 @@ public class IterableDataLoader : DataLoader<IList<torch.Tensor>, IList<torch.Te
197199
/// Indicates whether to dispose the dataset when being disposed.
198200
/// </param>
199201
public IterableDataLoader(
200-
IterableDataset dataset,
202+
IDataset<IEnumerable<Tensor>> dataset,
201203
int batchSize, IEnumerable<long> shuffler,
202204
Device device = null,
203205
int num_worker = 1, bool drop_last = false,
@@ -230,7 +232,7 @@ public IterableDataLoader(
230232
/// Indicates whether to dispose the dataset when being disposed.
231233
/// </param>
232234
public IterableDataLoader(
233-
IterableDataset dataset,
235+
IDataset<IEnumerable<Tensor>> dataset,
234236
int batchSize, bool shuffle = false,
235237
Device device = null, int? seed = null,
236238
int num_worker = 1, bool drop_last = false,
@@ -243,12 +245,18 @@ public IterableDataLoader(
243245
{
244246
}
245247

246-
private static IList<torch.Tensor> Collate(IEnumerable<IList<torch.Tensor>> dic, torch.Device device)
248+
private static IList<torch.Tensor> Collate(
249+
IReadOnlyList<IEnumerable<torch.Tensor>> dic, torch.Device device)
247250
{
251+
var dicCopy = new List<torch.Tensor[]>();
252+
foreach (var e in dic) {
253+
dicCopy.Add(e.ToArray());
254+
}
255+
248256
using (torch.NewDisposeScope()) {
249257
List<torch.Tensor> batch = new();
250-
for (var x = 0; x < dic.First().Count; x++) {
251-
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
258+
for (var x = 0; x < dicCopy[0].Length; x++) {
259+
var t = cat(dicCopy.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
252260
if (t.device_type != device.type || t.device_index != device.index)
253261
t = t.to(device);
254262
batch.Add(t.MoveToOuterDisposeScope());
@@ -264,12 +272,12 @@ public IterableDataLoader(
264272
/// </summary>
265273
public class DataLoader<T, S> : IEnumerable<S>, IDisposable
266274
{
267-
public Dataset<T> dataset { get; }
275+
public IDataset<T> dataset { get; }
268276
public int batch_size { get; }
269277
public bool drop_last { get; }
270278
public IEnumerable<long> sampler { get; }
271279
public int num_workers { get; }
272-
public Func<IEnumerable<T>, Device, S> collate_fn { get; }
280+
public Func<IReadOnlyList<T>, Device, S> collate_fn { get; }
273281

274282
public Device Device { get; }
275283
public bool DisposeBatch { get; }
@@ -295,9 +303,9 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
295303
/// Indicates whether to dispose the dataset when being disposed.
296304
/// </param>
297305
public DataLoader(
298-
Dataset<T> dataset,
306+
IDataset<T> dataset,
299307
int batchSize,
300-
Func<IEnumerable<T>, torch.Device, S> collate_fn,
308+
Func<IReadOnlyList<T>, torch.Device, S> collate_fn,
301309
IEnumerable<long> shuffler,
302310
Device? device = null,
303311
int num_worker = 1,
@@ -337,9 +345,9 @@ public DataLoader(
337345
/// Indicates whether to dispose the dataset when being disposed.
338346
/// </param>
339347
public DataLoader(
340-
Dataset<T> dataset,
348+
IDataset<T> dataset,
341349
int batchSize,
342-
Func<IEnumerable<T>, torch.Device, S> collate_fn,
350+
Func<IReadOnlyList<T>, torch.Device, S> collate_fn,
343351
bool shuffle = false,
344352
Device? device = null,
345353
int? seed = null,
@@ -432,7 +440,7 @@ public bool MoveNext()
432440
.WithDegreeOfParallelism(loader.num_workers)
433441
.ForAll((i) => {
434442
using var getTensorScope = torch.NewDisposeScope();
435-
tensors[i] = loader.dataset.GetTensor(indices[i]);
443+
tensors[i] = loader.dataset[indices[i]];
436444
getTensorDisposables[i] = getTensorScope.DetachAllAndDispose();
437445
});
438446

src/TorchSharp/Dataset.cs

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
22
using System;
33
using System.Collections.Generic;
4+
using System.Runtime.CompilerServices;
45

56
namespace TorchSharp
67
{
@@ -13,21 +14,29 @@ public static partial class data
1314
/// <summary>
1415
/// Map-style data set
1516
/// </summary>
16-
public abstract class Dataset : Dataset<Dictionary<string, torch.Tensor>>
17+
public abstract class Dataset : Dataset<Dictionary<string, Tensor>>,
    IDataset<IReadOnlyDictionary<string, Tensor>>
{
    // IDataset<out T> is covariant, so this type should already be usable as
    // IDataset<IReadOnlyDictionary<string, Tensor>> without any extra code.
    // The interface is still implemented explicitly because FSharp.Examples
    // breaks without it.
    IReadOnlyDictionary<string, Tensor> IDataset<IReadOnlyDictionary<string, Tensor>>.this[long index]
    {
        get { return this[index]; }
    }
}
1924

2025
/// <summary>
2126
/// Iterable-style data sets
2227
/// </summary>
23-
public abstract class IterableDataset : Dataset<IList<Tensor>>
28+
public abstract class IterableDataset : Dataset<IList<Tensor>>,
    IDataset<IEnumerable<Tensor>>
{
    // IDataset<out T> is covariant, so this type should already be usable as
    // IDataset<IEnumerable<Tensor>> without any extra code.
    // The interface is still implemented explicitly because FSharp.Examples
    // breaks without it.
    IEnumerable<Tensor> IDataset<IEnumerable<Tensor>>.this[long index]
    {
        get { return this[index]; }
    }
}
2635

2736
/// <summary>
2837
/// The base class for all Datasets.
2938
/// </summary>
30-
public abstract class Dataset<T> : IDisposable
39+
public abstract class Dataset<T> : IDataset<T>, IDisposable
3140
{
3241
public void Dispose()
3342
{
@@ -40,6 +49,12 @@ public void Dispose()
4049
/// </summary>
4150
public abstract long Count { get; }
4251

52+
/// <summary>
/// Gets the item at <paramref name="index"/> by forwarding to <c>GetTensor</c>.
/// </summary>
/// <param name="index">Index of the item to fetch.</param>
[IndexerName("DatasetItems")]
public T this[long index]
{
    get { return this.GetTensor(index); }
}
54+
55+
// GetTensor is kept for compatibility.
56+
// Perhaps we should remove that and make the indexer abstract later.
57+
4358
/// <summary>
4459
/// Get tensor according to index
4560
/// </summary>
@@ -49,8 +64,31 @@ public void Dispose()
4964

5065
/// <summary>
/// Extension point for releasing resources held by a dataset. The base
/// implementation does nothing; derived classes override it to dispose what they own.
/// </summary>
/// <param name="disposing">
/// Flag forwarded to overrides; presumably true when invoked from <c>Dispose()</c>
/// (the caller is not visible here, so confirm against it).
/// </param>
protected virtual void Dispose(bool disposing)
{
    // Intentionally empty. The two unused locals that used to live here were
    // leftover scaffolding for checking IDataset<out T> covariance.
}
5370
}
71+
72+
/// <summary>
73+
/// The base interface for all Datasets.
74+
/// </summary>
75+
public interface IDataset<out T> : IDisposable
{
    /// <summary>
    /// Number of items in the dataset.
    /// </summary>
    long Count { get; }

    /// <summary>
    /// Gets the item stored at the given position.
    /// </summary>
    /// <param name="index">Position of the requested item.</param>
    /// <returns>The item at <paramref name="index"/>. DataLoader will catenate these tensors into batches.</returns>
    [IndexerName("DatasetItems")]
    T this[long index] { get; }

    // TODO: support System.Index
}
5492
}
5593
}
5694
}

0 commit comments

Comments
 (0)