Skip to content

Commit f3f14d0

Browse files
committed
Update DataLoader.cs
1 parent c31b669 commit f3f14d0

File tree

1 file changed

+93
-130
lines changed

1 file changed

+93
-130
lines changed

src/TorchSharp/DataLoader.cs

Lines changed: 93 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System;
44
using System.Collections;
55
using System.Collections.Generic;
6+
using System.Diagnostics.CodeAnalysis;
67
using System.Diagnostics.SymbolStore;
78
using System.Linq;
89
using System.Threading;
@@ -158,20 +159,21 @@ public IterableDataLoader(IterableDataset dataset, int batchSize, bool shuffle =
158159
}
159160
}
160161

162+
#nullable enable
161163
/// <summary>
162164
/// This class supports creating batches from data sets.
163165
/// </summary>
164166
public class DataLoader<T, S> : IEnumerable<S>, IDisposable
165167
{
166-
private Dataset<T> dataset;
167-
private int batchSize;
168-
private bool shuffle;
169-
private bool drop_last;
170-
private Device device;
171-
private IEnumerable<long> shuffler;
172-
private int num_worker;
173-
private Func<IEnumerable<T>, torch.Device, S> collate_fn;
174-
private bool autoDispose;
168+
private readonly Dataset<T> dataset;
169+
private readonly int batchSize;
170+
private readonly bool drop_last;
171+
private readonly Device device;
172+
private readonly IEnumerable<long> shuffler;
173+
private readonly int num_worker;
174+
private readonly Func<IEnumerable<T>, torch.Device, S> collate_fn;
175+
private readonly bool disposeBatch;
176+
private readonly bool disposeDataset;
175177

176178
/// <summary>
177179
/// Pytorch style dataloader
@@ -180,34 +182,38 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
180182
/// <param name="batchSize">Size of batch</param>
181183
/// <param name="collate_fn">Callback to merge items to make a batch</param>
182184
/// <param name="device">device for output tensor</param>
183-
/// <param name="shuffler">Shuffler for dataloader</param>
185+
/// <param name="shuffler">Shuffler for dataloader.</param>
184186
/// <param name="num_worker">Number of workers</param>
185187
/// <param name="drop_last">
186188
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
187189
/// If false and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.
188190
/// </param>
189-
/// <param name="autoDispose">
190-
/// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
191+
/// <param name="disposeBatch">
192+
/// Indicates whether to automatically dispose the collated tensors after an iteration.
193+
/// </param>
194+
/// <param name="disposeDataset">
195+
/// Indicates whether to dispose the dataset when being disposed.
191196
/// </param>
192197
public DataLoader(
193198
Dataset<T> dataset,
194199
int batchSize,
195200
Func<IEnumerable<T>, torch.Device, S> collate_fn,
196201
IEnumerable<long> shuffler,
197-
Device device = null,
202+
Device? device = null,
198203
int num_worker = 1,
199204
bool drop_last = false,
200-
bool autoDispose = true)
205+
bool disposeBatch = true,
206+
bool disposeDataset = true)
201207
{
202208
this.dataset = dataset;
203209
this.batchSize = batchSize;
204-
this.shuffle = true;
205210
this.drop_last = drop_last;
206211
this.device = device ?? CPU;
207212
this.shuffler = shuffler;
208-
this.num_worker = num_worker;
213+
this.num_worker = Math.Max(num_worker, 1);
209214
this.collate_fn = collate_fn;
210-
this.autoDispose = autoDispose;
215+
this.disposeBatch = disposeBatch;
216+
this.disposeDataset = disposeDataset;
211217
}
212218

213219
/// <summary>
@@ -224,39 +230,39 @@ public DataLoader(
224230
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
225231
/// If false and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.
226232
/// </param>
227-
/// <param name="autoDispose">
233+
/// <param name="disposeBatch">
228234
/// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
229235
/// </param>
236+
/// <param name="disposeDataset">
237+
/// Indicates whether to dispose the dataset when being disposed.
238+
/// </param>
230239
public DataLoader(
231240
Dataset<T> dataset,
232241
int batchSize,
233242
Func<IEnumerable<T>, torch.Device, S> collate_fn,
234243
bool shuffle = false,
235-
Device device = null,
244+
Device? device = null,
236245
int? seed = null,
237246
int num_worker = 1,
238247
bool drop_last = false,
239-
bool autoDispose = true)
248+
bool disposeBatch = true,
249+
bool disposeDataset = true) :
250+
this(dataset, batchSize, collate_fn,
251+
shuffle ? new FisherYatesShuffler(dataset.Count, seed) : LongRange(dataset.Count),
252+
device, num_worker, drop_last, disposeBatch, disposeDataset)
253+
{ }
254+
255+
static IEnumerable<long> LongRange(long count)
240256
{
241-
this.dataset = dataset;
242-
this.batchSize = batchSize;
243-
this.shuffle = shuffle;
244-
this.drop_last = drop_last;
245-
this.device = device ?? CPU;
246-
this.shuffler = seed is null ? new FisherYatesShuffler(dataset.Count) : new FisherYatesShuffler(dataset.Count, seed);
247-
this.num_worker = num_worker;
248-
this.collate_fn = collate_fn;
249-
this.autoDispose = autoDispose;
257+
for (long i = 0; i < count; i++)
258+
yield return i;
250259
}
251260

252261
/// <summary>
253262
/// Generate enumerator
254263
/// </summary>
255264
/// <returns>Enumerator for batch</returns>
256-
public IEnumerator<S> GetEnumerator() =>
257-
new DataLoaderEnumerator(
258-
dataset, batchSize, shuffle, device,
259-
shuffler, num_worker, collate_fn, autoDispose);
265+
public IEnumerator<S> GetEnumerator() => new DataLoaderEnumerator(this);
260266

261267
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
262268

@@ -265,50 +271,38 @@ public IEnumerator<S> GetEnumerator() =>
265271
/// </summary>
266272
public long Count => drop_last ? (dataset.Count / batchSize) : ((dataset.Count - 1) / batchSize + 1);
267273

274+
public void Dispose()
275+
{
276+
Dispose(true);
277+
GC.SuppressFinalize(this);
278+
}
279+
280+
protected virtual void Dispose(bool disposing)
281+
{
282+
if (disposing && disposeDataset) {
283+
dataset.Dispose();
284+
}
285+
}
286+
268287
private class DataLoaderEnumerator : IEnumerator<S>
269288
{
270-
private Dataset<T> dataset;
271-
private int batchSize;
272-
private Device device;
273-
private bool shuffle;
274-
private IEnumerable<long> shuffleEnumerable;
289+
private readonly DataLoader<T, S> loader;
275290
private IEnumerator<long> shuffler;
276-
private long currentVal = 0;
277-
private int num_worker = 0;
278-
private List<IDisposable> currentDisposables;
279-
private Func<IEnumerable<T>, torch.Device, S> collate_fn;
280-
public DataLoaderEnumerator(
281-
Dataset<T> dataset,
282-
int batchSize,
283-
bool shuffle,
284-
Device device,
285-
IEnumerable<long> shuffleEnumerable,
286-
int num_worker,
287-
Func<IEnumerable<T>, torch.Device, S> collate_fn,
288-
bool autoDispose)
291+
private List<IDisposable>? currentDisposables;
292+
public DataLoaderEnumerator(DataLoader<T, S> loader)
289293
{
290-
this.dataset = dataset;
291-
this.batchSize = batchSize;
292-
this.device = device;
293-
this.shuffle = shuffle;
294-
this.shuffleEnumerable = shuffleEnumerable;
295-
if (num_worker < 1) num_worker = 1;
296-
this.num_worker = num_worker;
297-
this.collate_fn = collate_fn;
298-
this.currentDisposables = autoDispose ? new List<IDisposable>() : null;
294+
this.loader = loader;
295+
if (loader.disposeBatch)
296+
this.currentDisposables = new List<IDisposable>();
299297
Reset();
300298
}
301299

302-
private bool MoveNextValue()
300+
private long? MoveNextValue()
303301
{
304-
if (shuffle) {
305-
if (!shuffler.MoveNext()) return false;
306-
currentVal = shuffler.Current;
307-
return true;
308-
} else {
309-
currentVal++;
310-
return currentVal < dataset.Count;
302+
if (!shuffler.MoveNext()) {
303+
return null;
311304
}
305+
return shuffler.Current;
312306
}
313307

314308
/// <summary>
@@ -318,107 +312,76 @@ private bool MoveNextValue()
318312
public bool MoveNext()
319313
{
320314
DisposeCurrent();
321-
using (var scope = DisposeScopeManager.NewDisposeScope()) {
322-
if (!MoveNextValue()) return false;
323315

324-
var tensorIndexList = new List<long> { currentVal };
325-
for (int i = 1; i < batchSize; i++) {
326-
if (!MoveNextValue()) break;
327-
tensorIndexList.Add(currentVal);
316+
using (var scope = torch.NewDisposeScope()) {
317+
var indices = Enumerable.Range(0, loader.batchSize)
318+
.Select(_ => MoveNextValue())
319+
.Where(x => x.HasValue)
320+
.Cast<long>()
321+
.ToArray();
322+
if (loader.drop_last && indices.Length < loader.batchSize) {
323+
return false;
328324
}
329325

330-
var items = new List<T>(new T[tensorIndexList.Count]);
331-
var taskedBatchCount = 0;
332-
333-
//Run Async
334-
var tasks = new List<Task>();
335-
foreach (var _ in Enumerable.Range(1, num_worker - 1))
336-
tasks.Add(new(ProcessPendingBatches));
337-
tasks.ForEach(x => x.Start());
338-
339-
ProcessPendingBatches();
326+
var tensors = new T[indices.Length];
327+
Enumerable.Range(0, indices.Length)
328+
.AsParallel()
329+
.WithDegreeOfParallelism(loader.num_worker)
330+
.ForAll((i) => {
331+
tensors[i] = loader.dataset.GetTensor(indices[i]);
332+
});
340333

341-
foreach (var task in tasks)
342-
task.Wait();
343-
344-
if (this.currentDisposables is not null) {
334+
if (this.currentDisposables is null) {
335+
current = loader.collate_fn(tensors, loader.device);
336+
}
337+
else {
345338
using (var collate_scope = DisposeScopeManager.NewDisposeScope()) {
346-
Current = collate_fn(items, device);
339+
current = loader.collate_fn(tensors, loader.device);
347340
currentDisposables.AddRange(collate_scope.DisposablesView);
348341
collate_scope.Detach(currentDisposables);
349342
}
350343
}
351-
else {
352-
Current = collate_fn(items, device);
353-
}
354344

355-
foreach (var item in items) {
356-
dataset.DisposeTensor(item);
345+
foreach (var item in tensors) {
346+
loader.dataset.DisposeTensor(item);
357347
}
358348

359349
return true;
360-
361-
void ProcessPendingBatches()
362-
{
363-
while (true) {
364-
var idx = ScheduleBatch();
365-
if (idx is null) break;
366-
items[idx.Value.Item1] = dataset.GetTensor(idx.Value.Item2);
367-
}
368-
}
369-
370-
(int, long)? ScheduleBatch()
371-
{
372-
var t = Interlocked.Increment(ref taskedBatchCount) - 1;
373-
if (t < tensorIndexList.Count)
374-
return (t, tensorIndexList[t]);
375-
return null;
376-
}
377350
}
378351
}
379352

380353
/// <summary>
381354
/// Reset enumerator
382355
/// </summary>
356+
[MemberNotNull(nameof(shuffler))]
383357
public void Reset()
384358
{
385359
DisposeCurrent();
386-
if (shuffle) shuffler = shuffleEnumerable.GetEnumerator();
387-
currentVal = -1;
360+
shuffler?.Dispose();
361+
shuffler = loader.shuffler.GetEnumerator();
388362
}
389363

364+
S? current;
390365
/// <summary>
391366
/// Current tensor
392367
/// </summary>
393-
public S Current { get; private set; }
368+
public S Current => current!;
394369

395-
object IEnumerator.Current => Current;
370+
object IEnumerator.Current => current!;
396371

397372
public void Dispose()
398373
{
374+
shuffler.Dispose();
399375
DisposeCurrent();
400376
}
401377

402378
private void DisposeCurrent()
403379
{
404-
if (currentDisposables is null) return;
380+
if (currentDisposables is null)
381+
return;
405382
foreach (var x in currentDisposables)
406383
x.Dispose();
407384
currentDisposables.Clear();
408-
shuffler?.Dispose();
409-
}
410-
}
411-
412-
public void Dispose()
413-
{
414-
Dispose(true);
415-
GC.SuppressFinalize(this);
416-
}
417-
418-
protected virtual void Dispose(bool disposing)
419-
{
420-
if (disposing) {
421-
dataset.Dispose();
422385
}
423386
}
424387
}

0 commit comments

Comments
 (0)