Skip to content

Commit b5be13e

Browse files
Merge branch 'main' into missing
2 parents 0212ba5 + 06a83e3 commit b5be13e

File tree

1 file changed

+24
-23
lines changed

1 file changed

+24
-23
lines changed

src/TorchSharp/DataLoader.cs

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -264,15 +264,16 @@ public IterableDataLoader(
264264
/// </summary>
265265
public class DataLoader<T, S> : IEnumerable<S>, IDisposable
266266
{
267-
private readonly Dataset<T> dataset;
268-
private readonly int batchSize;
269-
private readonly bool drop_last;
270-
private readonly Device device;
271-
private readonly IEnumerable<long> shuffler;
272-
private readonly int num_worker;
273-
private readonly Func<IEnumerable<T>, torch.Device, S> collate_fn;
274-
private readonly bool disposeBatch;
275-
private readonly bool disposeDataset;
267+
public Dataset<T> dataset { get; }
268+
public int batch_size { get; }
269+
public bool drop_last { get; }
270+
public IEnumerable<long> sampler { get; }
271+
public int num_workers { get; }
272+
public Func<IEnumerable<T>, Device, S> collate_fn { get; }
273+
274+
public Device Device { get; }
275+
public bool DisposeBatch { get; }
276+
public bool DisposeDataset { get; }
276277

277278
/// <summary>
278279
/// Pytorch style dataloader
@@ -305,14 +306,14 @@ public DataLoader(
305306
bool disposeDataset = true)
306307
{
307308
this.dataset = dataset;
308-
this.batchSize = batchSize;
309+
this.batch_size = batchSize;
309310
this.drop_last = drop_last;
310-
this.device = device ?? CPU;
311-
this.shuffler = shuffler;
312-
this.num_worker = Math.Max(num_worker, 1);
311+
this.Device = device ?? CPU;
312+
this.sampler = shuffler;
313+
this.num_workers = Math.Max(num_worker, 1);
313314
this.collate_fn = collate_fn;
314-
this.disposeBatch = disposeBatch;
315-
this.disposeDataset = disposeDataset;
315+
this.DisposeBatch = disposeBatch;
316+
this.DisposeDataset = disposeDataset;
316317
}
317318

318319
/// <summary>
@@ -368,7 +369,7 @@ static IEnumerable<long> LongRange(long count)
368369
/// <summary>
369370
/// Size of batch
370371
/// </summary>
371-
public long Count => drop_last ? (dataset.Count / batchSize) : ((dataset.Count - 1) / batchSize + 1);
372+
public long Count => drop_last ? (dataset.Count / batch_size) : ((dataset.Count - 1) / batch_size + 1);
372373

373374
public void Dispose()
374375
{
@@ -378,7 +379,7 @@ public void Dispose()
378379

379380
protected virtual void Dispose(bool disposing)
380381
{
381-
if (disposing && disposeDataset) {
382+
if (disposing && DisposeDataset) {
382383
dataset.Dispose();
383384
}
384385
}
@@ -411,7 +412,7 @@ public bool MoveNext()
411412
{
412413
DisposeCurrent();
413414

414-
var indices = Enumerable.Range(0, loader.batchSize)
415+
var indices = Enumerable.Range(0, loader.batch_size)
415416
.Select(_ => shuffler.MoveNext() ? shuffler.Current : (long?)null)
416417
.Where(x => x.HasValue)
417418
.Cast<long>()
@@ -420,25 +421,25 @@ public bool MoveNext()
420421
if (indices.Length is 0)
421422
return false;
422423

423-
if (loader.drop_last && indices.Length < loader.batchSize) {
424+
if (loader.drop_last && indices.Length < loader.batch_size) {
424425
return false;
425426
}
426427

427428
var tensors = new T[indices.Length];
428429
var getTensorDisposables = new HashSet<IDisposable>[indices.Length];
429430
Enumerable.Range(0, indices.Length)
430431
.AsParallel()
431-
.WithDegreeOfParallelism(loader.num_worker)
432+
.WithDegreeOfParallelism(loader.num_workers)
432433
.ForAll((i) => {
433434
using var getTensorScope = torch.NewDisposeScope();
434435
tensors[i] = loader.dataset.GetTensor(indices[i]);
435436
getTensorDisposables[i] = getTensorScope.DetachAllAndDispose();
436437
});
437438

438439
using var collateScope = torch.NewDisposeScope();
439-
this.current = loader.collate_fn(tensors, loader.device);
440+
this.current = loader.collate_fn(tensors, loader.Device);
440441
var collateDisposables = collateScope.DetachAllAndDispose();
441-
if (loader.disposeBatch) {
442+
if (loader.DisposeBatch) {
442443
this.currentDisposables = collateDisposables;
443444
}
444445

@@ -456,7 +457,7 @@ public void Reset()
456457
{
457458
DisposeCurrent();
458459
shuffler?.Dispose();
459-
shuffler = loader.shuffler.GetEnumerator();
460+
shuffler = loader.sampler.GetEnumerator();
460461
}
461462

462463
S? current;

0 commit comments

Comments
 (0)