Skip to content

Commit 106a74d

Browse files
committed
separate scope
1 parent 8796ee9 commit 106a74d

File tree

1 file changed

+39
-34
lines changed

1 file changed

+39
-34
lines changed

src/TorchSharp/DataLoader.cs

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -396,12 +396,11 @@ public DataLoaderEnumerator(DataLoader<T, S> loader)
396396
Reset();
397397
}
398398

399-
private long? MoveNextValue()
399+
private static void DisposeAll(HashSet<IDisposable> disposables)
400400
{
401-
if (!shuffler.MoveNext()) {
402-
return null;
401+
foreach (var disposable in disposables) {
402+
disposable.Dispose();
403403
}
404-
return shuffler.Current;
405404
}
406405

407406
/// <summary>
@@ -412,35 +411,42 @@ public bool MoveNext()
412411
{
413412
DisposeCurrent();
414413

415-
using (var scope = torch.NewDisposeScope()) {
416-
var indices = Enumerable.Range(0, loader.batchSize)
417-
.Select(_ => MoveNextValue())
418-
.Where(x => x.HasValue)
419-
.Cast<long>()
420-
.ToArray();
421-
if (indices.Length is 0)
422-
return false;
423-
if (loader.drop_last && indices.Length < loader.batchSize) {
424-
return false;
425-
}
426-
427-
var tensors = new T[indices.Length];
428-
Enumerable.Range(0, indices.Length)
429-
.AsParallel()
430-
.WithDegreeOfParallelism(loader.num_worker)
431-
.ForAll((i) => {
432-
tensors[i] = loader.dataset.GetTensor(indices[i]);
433-
});
434-
435-
using var collate_scope = torch.NewDisposeScope();
436-
current = loader.collate_fn(tensors, loader.device);
437-
var disposables = collate_scope.DetachAllAndDispose();
438-
if (loader.disposeBatch) {
439-
this.currentDisposables = disposables;
440-
}
441-
442-
return true;
414+
var indices = Enumerable.Range(0, loader.batchSize)
415+
.Select(_ => shuffler.MoveNext() ? shuffler.Current : (long?)null)
416+
.Where(x => x.HasValue)
417+
.Cast<long>()
418+
.ToArray();
419+
420+
if (indices.Length is 0)
421+
return false;
422+
423+
if (loader.drop_last && indices.Length < loader.batchSize) {
424+
return false;
425+
}
426+
427+
var tensors = new T[indices.Length];
428+
var getTensorDisposables = new HashSet<IDisposable>[indices.Length];
429+
Enumerable.Range(0, indices.Length)
430+
.AsParallel()
431+
.WithDegreeOfParallelism(loader.num_worker)
432+
.ForAll((i) => {
433+
using var getTensorScope = torch.NewDisposeScope();
434+
tensors[i] = loader.dataset.GetTensor(indices[i]);
435+
getTensorDisposables[i] = getTensorScope.DetachAllAndDispose();
436+
});
437+
438+
using var collateScope = torch.NewDisposeScope();
439+
this.current = loader.collate_fn(tensors, loader.device);
440+
var collateDisposables = collateScope.DetachAllAndDispose();
441+
if (loader.disposeBatch) {
442+
this.currentDisposables = collateDisposables;
443443
}
444+
445+
foreach (var set in getTensorDisposables) {
446+
DisposeAll(set);
447+
}
448+
449+
return true;
444450
}
445451

446452
/// <summary>
@@ -470,8 +476,7 @@ public void Dispose()
470476
private void DisposeCurrent()
471477
{
472478
if (this.currentDisposables is not null) {
473-
foreach (var x in this.currentDisposables)
474-
x.Dispose();
479+
DisposeAll(this.currentDisposables);
475480
this.currentDisposables = null;
476481
}
477482
}

0 commit comments

Comments (0)