Skip to content

Commit aedf161

Browse files
committed
IDataset interface
1 parent c332e9d commit aedf161

File tree

3 files changed

+62
-34
lines changed

3 files changed

+62
-34
lines changed

src/TorchSharp/DataLoader.cs

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@ public static partial class utils
1717
{
1818
public static partial class data
1919
{
20-
2120
public static Modules.DataLoader DataLoader(
22-
Dataset dataset,
21+
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
2322
int batchSize, IEnumerable<long> shuffler,
2423
Device device = null,
2524
int num_worker = 1, bool drop_last = false,
@@ -34,7 +33,7 @@ public static Modules.DataLoader DataLoader(
3433
}
3534

3635
public static Modules.DataLoader DataLoader(
37-
Dataset dataset,
36+
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
3837
int batchSize, bool shuffle = false,
3938
Device device = null, int? seed = null,
4039
int num_worker = 1, bool drop_last = false,
@@ -49,7 +48,7 @@ public static Modules.DataLoader DataLoader(
4948
}
5049

5150
public static Modules.IterableDataLoader DataLoader(
52-
IterableDataset dataset,
51+
IDataset<IEnumerable<Tensor>> dataset,
5352
int batchSize, IEnumerable<long> shuffler,
5453
Device device = null,
5554
int num_worker = 1, bool drop_last = false,
@@ -64,7 +63,7 @@ public static Modules.IterableDataLoader DataLoader(
6463
}
6564

6665
public static Modules.IterableDataLoader DataLoader(
67-
IterableDataset dataset,
66+
IDataset<IEnumerable<Tensor>> dataset,
6867
int batchSize, bool shuffle = false,
6968
Device device = null, int? seed = null,
7069
int num_worker = 1, bool drop_last = false,
@@ -90,7 +89,8 @@ namespace Modules
9089
/// Data loader. Combines a dataset and a sampler, and provides an enumerator over the given dataset.
9190
/// </summary>
9291
/// <remarks>This class is used for map-style data sets</remarks>
93-
public class DataLoader : DataLoader<Dictionary<string, torch.Tensor>, Dictionary<string, torch.Tensor>>
92+
public class DataLoader : DataLoader<IReadOnlyDictionary<string, torch.Tensor>,
93+
Dictionary<string, torch.Tensor>>
9494
{
9595
/// <summary>
9696
/// Pytorch style dataloader
@@ -111,7 +111,7 @@ public class DataLoader : DataLoader<Dictionary<string, torch.Tensor>, Dictionar
111111
/// Indicates whether to dispose the dataset when being disposed.
112112
/// </param>
113113
public DataLoader(
114-
Dataset dataset,
114+
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
115115
int batchSize, IEnumerable<long> shuffler,
116116
Device device = null,
117117
int num_worker = 1, bool drop_last = false,
@@ -144,7 +144,7 @@ public DataLoader(
144144
/// Indicates whether to dispose the dataset when being disposed.
145145
/// </param>
146146
public DataLoader(
147-
Dataset dataset,
147+
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
148148
int batchSize, bool shuffle = false,
149149
Device device = null, int? seed = null,
150150
int num_worker = 1, bool drop_last = false,
@@ -157,7 +157,8 @@ public DataLoader(
157157
{
158158
}
159159

160-
private static Dictionary<string, torch.Tensor> Collate(IEnumerable<Dictionary<string, torch.Tensor>> dic, torch.Device device)
160+
private static Dictionary<string, torch.Tensor> Collate(
161+
IEnumerable<IReadOnlyDictionary<string, torch.Tensor>> dic, torch.Device device)
161162
{
162163
using (torch.NewDisposeScope()) {
163164
Dictionary<string, torch.Tensor> batch = new();
@@ -176,7 +177,8 @@ public DataLoader(
176177
/// Data loader. Combines a dataset and a sampler, and provides an enumerator over the given dataset.
177178
/// </summary>
178179
/// <remarks>This class is used for list-style data sets</remarks>
179-
public class IterableDataLoader : DataLoader<IList<torch.Tensor>, IList<torch.Tensor>>
180+
public class IterableDataLoader :
181+
DataLoader<IEnumerable<torch.Tensor>, IList<torch.Tensor>>
180182
{
181183
/// <summary>
182184
/// Pytorch style dataloader
@@ -197,7 +199,7 @@ public class IterableDataLoader : DataLoader<IList<torch.Tensor>, IList<torch.Te
197199
/// Indicates whether to dispose the dataset when being disposed.
198200
/// </param>
199201
public IterableDataLoader(
200-
IterableDataset dataset,
202+
IDataset<IEnumerable<Tensor>> dataset,
201203
int batchSize, IEnumerable<long> shuffler,
202204
Device device = null,
203205
int num_worker = 1, bool drop_last = false,
@@ -230,7 +232,7 @@ public IterableDataLoader(
230232
/// Indicates whether to dispose the dataset when being disposed.
231233
/// </param>
232234
public IterableDataLoader(
233-
IterableDataset dataset,
235+
IDataset<IEnumerable<Tensor>> dataset,
234236
int batchSize, bool shuffle = false,
235237
Device device = null, int? seed = null,
236238
int num_worker = 1, bool drop_last = false,
@@ -243,12 +245,18 @@ public IterableDataLoader(
243245
{
244246
}
245247

246-
private static IList<torch.Tensor> Collate(IEnumerable<IList<torch.Tensor>> dic, torch.Device device)
248+
private static IList<torch.Tensor> Collate(
249+
IReadOnlyList<IEnumerable<torch.Tensor>> dic, torch.Device device)
247250
{
251+
var dicCopy = new List<torch.Tensor[]>();
252+
foreach (var e in dic) {
253+
dicCopy.Add(e.ToArray());
254+
}
255+
248256
using (torch.NewDisposeScope()) {
249257
List<torch.Tensor> batch = new();
250-
for (var x = 0; x < dic.First().Count; x++) {
251-
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
258+
for (var x = 0; x < dicCopy[0].Length; x++) {
259+
var t = cat(dicCopy.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
252260
if (t.device_type != device.type || t.device_index != device.index)
253261
t = t.to(device);
254262
batch.Add(t.MoveToOuterDisposeScope());
@@ -264,12 +272,12 @@ public IterableDataLoader(
264272
/// </summary>
265273
public class DataLoader<T, S> : IEnumerable<S>, IDisposable
266274
{
267-
public Dataset<T> dataset { get; }
275+
public IDataset<T> dataset { get; }
268276
public int batch_size { get; }
269277
public bool drop_last { get; }
270278
public IEnumerable<long> sampler { get; }
271279
public int num_workers { get; }
272-
public Func<IEnumerable<T>, Device, S> collate_fn { get; }
280+
public Func<IReadOnlyList<T>, Device, S> collate_fn { get; }
273281

274282
public Device Device { get; }
275283
public bool DisposeBatch { get; }
@@ -295,9 +303,9 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
295303
/// Indicates whether to dispose the dataset when being disposed.
296304
/// </param>
297305
public DataLoader(
298-
Dataset<T> dataset,
306+
IDataset<T> dataset,
299307
int batchSize,
300-
Func<IEnumerable<T>, torch.Device, S> collate_fn,
308+
Func<IReadOnlyList<T>, torch.Device, S> collate_fn,
301309
IEnumerable<long> shuffler,
302310
Device? device = null,
303311
int num_worker = 1,
@@ -337,9 +345,9 @@ public DataLoader(
337345
/// Indicates whether to dispose the dataset when being disposed.
338346
/// </param>
339347
public DataLoader(
340-
Dataset<T> dataset,
348+
IDataset<T> dataset,
341349
int batchSize,
342-
Func<IEnumerable<T>, torch.Device, S> collate_fn,
350+
Func<IReadOnlyList<T>, torch.Device, S> collate_fn,
343351
bool shuffle = false,
344352
Device? device = null,
345353
int? seed = null,
@@ -432,7 +440,7 @@ public bool MoveNext()
432440
.WithDegreeOfParallelism(loader.num_workers)
433441
.ForAll((i) => {
434442
using var getTensorScope = torch.NewDisposeScope();
435-
tensors[i] = loader.dataset.GetTensor(indices[i]);
443+
tensors[i] = loader.dataset[indices[i]];
436444
getTensorDisposables[i] = getTensorScope.DetachAllAndDispose();
437445
});
438446

src/TorchSharp/Dataset.cs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
22
using System;
33
using System.Collections.Generic;
4+
using System.Runtime.CompilerServices;
45

56
namespace TorchSharp
67
{
@@ -27,7 +28,7 @@ public abstract class IterableDataset : Dataset<IList<Tensor>>
2728
/// <summary>
2829
/// The base interface for all Datasets.
2930
/// </summary>
30-
public abstract class Dataset<T> : IDisposable
31+
public abstract class Dataset<T> : IDataset<T>, IDisposable
3132
{
3233
public void Dispose()
3334
{
@@ -40,6 +41,12 @@ public void Dispose()
4041
/// </summary>
4142
public abstract long Count { get; }
4243

44+
[IndexerName("DatasetItems")]
45+
public T this[long index] => this.GetTensor(index);
46+
47+
// GetTensor is kept for compatibility.
48+
// Perhaps we should remove that and make the indexer abstract later.
49+
4350
/// <summary>
4451
/// Get tensor according to index
4552
/// </summary>
@@ -49,8 +56,31 @@ public void Dispose()
4956

5057
protected virtual void Dispose(bool disposing)
5158
{
59+
IDataset<Dictionary<string, string>> a = null;
60+
IDataset<IReadOnlyDictionary<string, string>> b = a;
5261
}
5362
}
63+
64+
/// <summary>
65+
/// The base interface for all Datasets.
66+
/// </summary>
67+
public interface IDataset<out T> : IDisposable
68+
{
69+
/// <summary>
70+
/// Size of dataset
71+
/// </summary>
72+
long Count { get; }
73+
74+
/// <summary>
75+
/// Get tensor according to index
76+
/// </summary>
77+
/// <param name="index">Index for tensor</param>
78+
/// <returns>Tensors of index. DataLoader will catenate these tensors into batches.</returns>
79+
[IndexerName("DatasetItems")]
80+
T this[long index] { get; }
81+
82+
// TODO: support System.Index
83+
}
5484
}
5585
}
5686
}

src/TorchSharp/Utils/TensorDataset.cs

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,6 @@ internal TensorDataset(torch.Tensor[] tensors)
4040
_tensors = tensors.Select(x => x.alias().DetachFromDisposeScope()).ToArray();
4141
}
4242

43-
/// <summary>
44-
/// Indexer
45-
/// </summary>
46-
public IList<torch.Tensor> this[long index] {
47-
48-
get {
49-
return _tensors.Select(t => t[index]).ToList();
50-
}
51-
}
52-
5343
/// <summary>
5444
/// Length of the dataset, i.e. the size of the first dimension.
5545
/// </summary>
@@ -59,7 +49,7 @@ public override long Count {
5949

6050
public override IList<torch.Tensor> GetTensor(long index)
6151
{
62-
return this[index];
52+
return _tensors.Select(t => t[index]).ToList();
6353
}
6454

6555
readonly torch.Tensor[] _tensors;

0 commit comments

Comments
 (0)