Skip to content

Commit 0009e27

Browse files
committed
no need for DisposeTensor
1 parent f3f14d0 commit 0009e27

File tree

5 files changed

+16
-49
lines changed

5 files changed

+16
-49
lines changed

RELEASENOTES.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,13 @@ __Bug Fixes__:
1818

1919
- #1300 `Adadelta`, `Adam` and `AdamW` will no longer throw `NullReferenceException` when `maximize` is `true` and `grad` is `null`.
2020
- `torch.normal` will now correctly return a leaf tensor.
21-
- A new option `autoDispose` has been added into `DataLoader`s, which indicates whether to dispose the collated tensors before the next iteration.
21+
- New options `disposeBatch` and `disposeDataset` have been added into `DataLoader`.
2222
- The default collate functions will now always dispose the intermediate tensors, rather than wait for the next iteration.
23-
- A new abstract method `DisposeTensor` has been added to `Dataset<>`s.
24-
- This method will be used by `DataLoader`, to dispose the values provided by `GetTensor`.
25-
- `Dataset` and `IterableDataset` has implemented this method, so if your dataset is inherited from them, please check whether the disposal should be avoided in your case.
2623

2724
__Bug Fixes__:
2825

2926
- `TensorDataset` will now keep the aliases detached from dispose scopes, to avoid the unexpected disposal.
27+
- `DataLoaderEnumerator` has been completely rewritten to resolve the unexpected shuffler disposal, the ignorance of drop last and the incorrect count of worker.
3028

3129
# NuGet Version 0.102.4
3230

src/TorchAudio/Datasets/SpeechCommandsDataset.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,6 @@ internal SpeechCommandsDataset(string path, string subset)
4848

4949
public override long Count => _walker.LongLength;
5050

51-
public override void DisposeTensor(SpeechCommandsDatasetItem tensor)
52-
{
53-
tensor.waveform.Dispose();
54-
}
55-
5651
public override SpeechCommandsDatasetItem GetTensor(long index)
5752
{
5853
var audioPath = _walker[index];

src/TorchAudio/Datasets/YesnoDataset.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,6 @@ public YesnoDataset(string directoryPathInArchive)
2222

2323
public override long Count => audioPathList.LongLength;
2424

25-
public override void DisposeTensor(YesnoDatasetItem tensor)
26-
{
27-
tensor.waveform.Dispose();
28-
}
29-
3025
public override YesnoDatasetItem GetTensor(long index)
3126
{
3227
var (waveform, sample_rate) = torchaudio.load(audioPathList[index]);

src/TorchSharp/DataLoader.cs

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -284,16 +284,17 @@ protected virtual void Dispose(bool disposing)
284284
}
285285
}
286286

287-
private class DataLoaderEnumerator : IEnumerator<S>
287+
sealed class DataLoaderEnumerator : IEnumerator<S>
288288
{
289289
private readonly DataLoader<T, S> loader;
290290
private IEnumerator<long> shuffler;
291-
private List<IDisposable>? currentDisposables;
291+
private IReadOnlyList<IDisposable> currentDisposables;
292292
public DataLoaderEnumerator(DataLoader<T, S> loader)
293293
{
294294
this.loader = loader;
295-
if (loader.disposeBatch)
296-
this.currentDisposables = new List<IDisposable>();
295+
this.currentDisposables = Array.Empty<IDisposable>();
296+
// TODO: Use MemberNotNull instead.
297+
shuffler = null!;
297298
Reset();
298299
}
299300

@@ -331,19 +332,14 @@ public bool MoveNext()
331332
tensors[i] = loader.dataset.GetTensor(indices[i]);
332333
});
333334

334-
if (this.currentDisposables is null) {
335-
current = loader.collate_fn(tensors, loader.device);
336-
}
337-
else {
338-
using (var collate_scope = DisposeScopeManager.NewDisposeScope()) {
339-
current = loader.collate_fn(tensors, loader.device);
340-
currentDisposables.AddRange(collate_scope.DisposablesView);
341-
collate_scope.Detach(currentDisposables);
342-
}
343-
}
335+
using var collate_scope = DisposeScopeManager.NewDisposeScope();
336+
current = loader.collate_fn(tensors, loader.device);
344337

345-
foreach (var item in tensors) {
346-
loader.dataset.DisposeTensor(item);
338+
// TODO: Will be better if we have something like DetachAll
339+
var view = collate_scope.DisposablesView;
340+
collate_scope.Detach(view);
341+
if (loader.disposeBatch) {
342+
this.currentDisposables = view;
347343
}
348344

349345
return true;
@@ -353,7 +349,6 @@ public bool MoveNext()
353349
/// <summary>
354350
/// Reset enumerator
355351
/// </summary>
356-
[MemberNotNull(nameof(shuffler))]
357352
public void Reset()
358353
{
359354
DisposeCurrent();
@@ -377,11 +372,9 @@ public void Dispose()
377372

378373
private void DisposeCurrent()
379374
{
380-
if (currentDisposables is null)
381-
return;
382-
foreach (var x in currentDisposables)
375+
foreach (var x in this.currentDisposables)
383376
x.Dispose();
384-
currentDisposables.Clear();
377+
this.currentDisposables = Array.Empty<IDisposable>();
385378
}
386379
}
387380
}

src/TorchSharp/Dataset.cs

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,13 @@ public static partial class data
1515
/// </summary>
1616
public abstract class Dataset : Dataset<Dictionary<string, torch.Tensor>>
1717
{
18-
public override void DisposeTensor(Dictionary<string, Tensor> tensor)
19-
{
20-
foreach (var t in tensor.Values) {
21-
t.Dispose();
22-
}
23-
}
2418
}
2519

2620
/// <summary>
2721
/// Iterable-style data sets
2822
/// </summary>
2923
public abstract class IterableDataset : Dataset<IList<Tensor>>
3024
{
31-
public override void DisposeTensor(IList<Tensor> tensor)
32-
{
33-
foreach (var t in tensor) {
34-
t.Dispose();
35-
}
36-
}
3725
}
3826

3927
/// <summary>
@@ -59,8 +47,6 @@ public void Dispose()
5947
/// <returns>Tensors of index. DataLoader will catenate these tensors into batches.</returns>
6048
public abstract T GetTensor(long index);
6149

62-
public abstract void DisposeTensor(T tensor);
63-
6450
protected virtual void Dispose(bool disposing)
6551
{
6652
}

0 commit comments

Comments
 (0)