Skip to content

Commit 26240d5

Browse files
Merge pull request #1358 from dotnet/revert-1354-ConcatDataset
Revert "concat dataset"
2 parents 35f4c7d + 2c3f40f commit 26240d5

File tree

6 files changed

+36
-281
lines changed

6 files changed

+36
-281
lines changed

RELEASENOTES.md

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,6 @@
22

33
Releases, starting with 9/2/2021, are listed with the most recent release at the top.
44

5-
# NuGet Version 0.102.7
6-
7-
__Breaking Changes__:
8-
9-
A new interface `IDataset<out T>` has been added. (Now `Dataset<T>` implements `IDataset<T>`; `Dataset` implements both `IDataset<Dictionary<string, Tensor>>` and `IDataset<IReadOnlyDictionary<string, Tensor>>`; `IterableDataset` implements `IDataset<IList<string, Tensor>>` and `IDataset<IEnumerable<string, Tensor>>`.)<br/>
10-
`torch.utils.data.ConcatDataset` has been added.<br/>
11-
12-
__API Changes__:
13-
14-
The parameter of `DataLoader`s has been relaxed to `IDataset`.<br/>
15-
The parameter of `DataLoader`s' collate functions has been relaxed to `IReadOnlyList`.<br/>
16-
175
# NuGet Version 0.102.6
186

197
__Breaking Changes__:

src/TorchSharp/ConcatDataset.cs

Lines changed: 0 additions & 108 deletions
This file was deleted.

src/TorchSharp/DataLoader.cs

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ public static partial class utils
1717
{
1818
public static partial class data
1919
{
20+
2021
public static Modules.DataLoader DataLoader(
21-
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
22+
Dataset dataset,
2223
int batchSize, IEnumerable<long> shuffler,
2324
Device device = null,
2425
int num_worker = 1, bool drop_last = false,
@@ -33,7 +34,7 @@ public static Modules.DataLoader DataLoader(
3334
}
3435

3536
public static Modules.DataLoader DataLoader(
36-
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
37+
Dataset dataset,
3738
int batchSize, bool shuffle = false,
3839
Device device = null, int? seed = null,
3940
int num_worker = 1, bool drop_last = false,
@@ -48,7 +49,7 @@ public static Modules.DataLoader DataLoader(
4849
}
4950

5051
public static Modules.IterableDataLoader DataLoader(
51-
IDataset<IEnumerable<Tensor>> dataset,
52+
IterableDataset dataset,
5253
int batchSize, IEnumerable<long> shuffler,
5354
Device device = null,
5455
int num_worker = 1, bool drop_last = false,
@@ -63,7 +64,7 @@ public static Modules.IterableDataLoader DataLoader(
6364
}
6465

6566
public static Modules.IterableDataLoader DataLoader(
66-
IDataset<IEnumerable<Tensor>> dataset,
67+
IterableDataset dataset,
6768
int batchSize, bool shuffle = false,
6869
Device device = null, int? seed = null,
6970
int num_worker = 1, bool drop_last = false,
@@ -89,8 +90,7 @@ namespace Modules
8990
/// Data loader. Combines a dataset and a sampler, and provides an enumerator over the given dataset.
9091
/// </summary>
9192
/// <remarks>This class is used for map-style data sets</remarks>
92-
public class DataLoader : DataLoader<IReadOnlyDictionary<string, torch.Tensor>,
93-
Dictionary<string, torch.Tensor>>
93+
public class DataLoader : DataLoader<Dictionary<string, torch.Tensor>, Dictionary<string, torch.Tensor>>
9494
{
9595
/// <summary>
9696
/// Pytorch style dataloader
@@ -111,7 +111,7 @@ public class DataLoader : DataLoader<IReadOnlyDictionary<string, torch.Tensor>,
111111
/// Indicates whether to dispose the dataset when being disposed.
112112
/// </param>
113113
public DataLoader(
114-
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
114+
Dataset dataset,
115115
int batchSize, IEnumerable<long> shuffler,
116116
Device device = null,
117117
int num_worker = 1, bool drop_last = false,
@@ -144,7 +144,7 @@ public DataLoader(
144144
/// Indicates whether to dispose the dataset when being disposed.
145145
/// </param>
146146
public DataLoader(
147-
IDataset<IReadOnlyDictionary<string, torch.Tensor>> dataset,
147+
Dataset dataset,
148148
int batchSize, bool shuffle = false,
149149
Device device = null, int? seed = null,
150150
int num_worker = 1, bool drop_last = false,
@@ -157,8 +157,7 @@ public DataLoader(
157157
{
158158
}
159159

160-
private static Dictionary<string, torch.Tensor> Collate(
161-
IEnumerable<IReadOnlyDictionary<string, torch.Tensor>> dic, torch.Device device)
160+
private static Dictionary<string, torch.Tensor> Collate(IEnumerable<Dictionary<string, torch.Tensor>> dic, torch.Device device)
162161
{
163162
using (torch.NewDisposeScope()) {
164163
Dictionary<string, torch.Tensor> batch = new();
@@ -177,8 +176,7 @@ public DataLoader(
177176
/// Data loader. Combines a dataset and a sampler, and provides an enumerator over the given dataset.
178177
/// </summary>
179178
/// <remarks>This class is used for list-style data sets</remarks>
180-
public class IterableDataLoader :
181-
DataLoader<IEnumerable<torch.Tensor>, IList<torch.Tensor>>
179+
public class IterableDataLoader : DataLoader<IList<torch.Tensor>, IList<torch.Tensor>>
182180
{
183181
/// <summary>
184182
/// Pytorch style dataloader
@@ -199,7 +197,7 @@ public class IterableDataLoader :
199197
/// Indicates whether to dispose the dataset when being disposed.
200198
/// </param>
201199
public IterableDataLoader(
202-
IDataset<IEnumerable<Tensor>> dataset,
200+
IterableDataset dataset,
203201
int batchSize, IEnumerable<long> shuffler,
204202
Device device = null,
205203
int num_worker = 1, bool drop_last = false,
@@ -232,7 +230,7 @@ public IterableDataLoader(
232230
/// Indicates whether to dispose the dataset when being disposed.
233231
/// </param>
234232
public IterableDataLoader(
235-
IDataset<IEnumerable<Tensor>> dataset,
233+
IterableDataset dataset,
236234
int batchSize, bool shuffle = false,
237235
Device device = null, int? seed = null,
238236
int num_worker = 1, bool drop_last = false,
@@ -245,18 +243,12 @@ public IterableDataLoader(
245243
{
246244
}
247245

248-
private static IList<torch.Tensor> Collate(
249-
IReadOnlyList<IEnumerable<torch.Tensor>> dic, torch.Device device)
246+
private static IList<torch.Tensor> Collate(IEnumerable<IList<torch.Tensor>> dic, torch.Device device)
250247
{
251-
var dicCopy = new List<torch.Tensor[]>();
252-
foreach (var e in dic) {
253-
dicCopy.Add(e.ToArray());
254-
}
255-
256248
using (torch.NewDisposeScope()) {
257249
List<torch.Tensor> batch = new();
258-
for (var x = 0; x < dicCopy[0].Length; x++) {
259-
var t = cat(dicCopy.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
250+
for (var x = 0; x < dic.First().Count; x++) {
251+
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
260252
if (t.device_type != device.type || t.device_index != device.index)
261253
t = t.to(device);
262254
batch.Add(t.MoveToOuterDisposeScope());
@@ -272,12 +264,12 @@ public IterableDataLoader(
272264
/// </summary>
273265
public class DataLoader<T, S> : IEnumerable<S>, IDisposable
274266
{
275-
public IDataset<T> dataset { get; }
267+
public Dataset<T> dataset { get; }
276268
public int batch_size { get; }
277269
public bool drop_last { get; }
278270
public IEnumerable<long> sampler { get; }
279271
public int num_workers { get; }
280-
public Func<IReadOnlyList<T>, Device, S> collate_fn { get; }
272+
public Func<IEnumerable<T>, Device, S> collate_fn { get; }
281273

282274
public Device Device { get; }
283275
public bool DisposeBatch { get; }
@@ -303,9 +295,9 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
303295
/// Indicates whether to dispose the dataset when being disposed.
304296
/// </param>
305297
public DataLoader(
306-
IDataset<T> dataset,
298+
Dataset<T> dataset,
307299
int batchSize,
308-
Func<IReadOnlyList<T>, torch.Device, S> collate_fn,
300+
Func<IEnumerable<T>, torch.Device, S> collate_fn,
309301
IEnumerable<long> shuffler,
310302
Device? device = null,
311303
int num_worker = 1,
@@ -345,9 +337,9 @@ public DataLoader(
345337
/// Indicates whether to dispose the dataset when being disposed.
346338
/// </param>
347339
public DataLoader(
348-
IDataset<T> dataset,
340+
Dataset<T> dataset,
349341
int batchSize,
350-
Func<IReadOnlyList<T>, torch.Device, S> collate_fn,
342+
Func<IEnumerable<T>, torch.Device, S> collate_fn,
351343
bool shuffle = false,
352344
Device? device = null,
353345
int? seed = null,
@@ -440,7 +432,7 @@ public bool MoveNext()
440432
.WithDegreeOfParallelism(loader.num_workers)
441433
.ForAll((i) => {
442434
using var getTensorScope = torch.NewDisposeScope();
443-
tensors[i] = loader.dataset[indices[i]];
435+
tensors[i] = loader.dataset.GetTensor(indices[i]);
444436
getTensorDisposables[i] = getTensorScope.DetachAllAndDispose();
445437
});
446438

src/TorchSharp/Dataset.cs

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
22
using System;
33
using System.Collections.Generic;
4-
using System.Runtime.CompilerServices;
54

65
namespace TorchSharp
76
{
@@ -14,29 +13,21 @@ public static partial class data
1413
/// <summary>
1514
/// Map-style data set
1615
/// </summary>
17-
public abstract class Dataset : Dataset<Dictionary<string, Tensor>>,
18-
IDataset<IReadOnlyDictionary<string, Tensor>>
16+
public abstract class Dataset : Dataset<Dictionary<string, torch.Tensor>>
1917
{
20-
// Due to covariation, it should naturally be IDataset<IReadOnlyDictionary<string, Tensor>>.
21-
// However FSharp.Examples will break down without this.
22-
IReadOnlyDictionary<string, Tensor> IDataset<IReadOnlyDictionary<string, Tensor>>.this[long index] => this[index];
2318
}
2419

2520
/// <summary>
2621
/// Iterable-style data sets
2722
/// </summary>
28-
public abstract class IterableDataset : Dataset<IList<Tensor>>,
29-
IDataset<IEnumerable<Tensor>>
23+
public abstract class IterableDataset : Dataset<IList<Tensor>>
3024
{
31-
// Due to covariation, it should naturally be IDataset<IEnumerable<Tensor>>.
32-
// However FSharp.Examples will break down without this.
33-
IEnumerable<Tensor> IDataset<IEnumerable<Tensor>>.this[long index] => this[index];
3425
}
3526

3627
/// <summary>
3728
/// The base nterface for all Datasets.
3829
/// </summary>
39-
public abstract class Dataset<T> : IDataset<T>, IDisposable
30+
public abstract class Dataset<T> : IDisposable
4031
{
4132
public void Dispose()
4233
{
@@ -49,12 +40,6 @@ public void Dispose()
4940
/// </summary>
5041
public abstract long Count { get; }
5142

52-
[IndexerName("DatasetItems")]
53-
public T this[long index] => this.GetTensor(index);
54-
55-
// GetTensor is kept for compatibility.
56-
// Perhaps we should remove that and make the indexer abstract later.
57-
5843
/// <summary>
5944
/// Get tensor according to index
6045
/// </summary>
@@ -64,31 +49,8 @@ public void Dispose()
6449

6550
protected virtual void Dispose(bool disposing)
6651
{
67-
IDataset<Dictionary<string, string>> a = null;
68-
IDataset<IReadOnlyDictionary<string, string>> b = a;
6952
}
7053
}
71-
72-
/// <summary>
73-
/// The base interface for all Datasets.
74-
/// </summary>
75-
public interface IDataset<out T> : IDisposable
76-
{
77-
/// <summary>
78-
/// Size of dataset
79-
/// </summary>
80-
long Count { get; }
81-
82-
/// <summary>
83-
/// Get tensor according to index
84-
/// </summary>
85-
/// <param name="index">Index for tensor</param>
86-
/// <returns>Tensors of index. DataLoader will catenate these tensors into batches.</returns>
87-
[IndexerName("DatasetItems")]
88-
T this[long index] { get; }
89-
90-
// TODO: support System.Index
91-
}
9254
}
9355
}
9456
}

0 commit comments

Comments
 (0)