Skip to content

Commit cfa9fde

Browse files
authored
Merge branch 'collate' into temp
2 parents 8b14055 + 3042d5f commit cfa9fde

File tree

6 files changed

+117
-26
lines changed

6 files changed

+117
-26
lines changed

RELEASENOTES.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ __Bug Fixes__:
1818

1919
- #1300 `Adadelta`, `Adam` and `AdamW` will no longer throw `NullReferenceException` when `maximize` is `true` and `grad` is `null`.
2020
- `torch.normal` will now correctly return a leaf tensor.
21+
- A new option `autoDispose` has been added to `DataLoader`s, which indicates whether to dispose the collated tensors before the next iteration.
22+
- The default collate functions will now always dispose the intermediate tensors, rather than wait for the next iteration.
23+
- A new abstract method `DisposeTensor` has been added to `Dataset<>`s.
24+
- This method will be used by `DataLoader` to dispose the values provided by `GetTensor`.
25+
- `Dataset` and `IterableDataset` have implemented this method, so if your dataset is inherited from them, please check whether the disposal should be avoided in your case.
26+
27+
__Bug Fixes__:
28+
29+
- `TensorDataset` will now keep the aliases detached from dispose scopes, to avoid unexpected disposal.
2130

2231
# NuGet Version 0.102.4
2332

src/TorchAudio/Datasets/SpeechCommandsDataset.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ internal SpeechCommandsDataset(string path, string subset)
4848

4949
public override long Count => _walker.LongLength;
5050

51+
public override void DisposeTensor(SpeechCommandsDatasetItem tensor)
52+
{
53+
tensor.waveform.Dispose();
54+
}
55+
5156
public override SpeechCommandsDatasetItem GetTensor(long index)
5257
{
5358
var audioPath = _walker[index];

src/TorchAudio/Datasets/YesnoDataset.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ public YesnoDataset(string directoryPathInArchive)
2222

2323
public override long Count => audioPathList.LongLength;
2424

25+
public override void DisposeTensor(YesnoDatasetItem tensor)
26+
{
27+
tensor.waveform.Dispose();
28+
}
29+
2530
public override YesnoDatasetItem GetTensor(long index)
2631
{
2732
var (waveform, sample_rate) = torchaudio.load(audioPathList[index]);

src/TorchSharp/DataLoader.cs

Lines changed: 73 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,16 @@ public DataLoader(Dataset dataset, int batchSize, bool shuffle = false, Device d
8989

9090
private static Dictionary<string, torch.Tensor> Collate(IEnumerable<Dictionary<string, torch.Tensor>> dic, torch.Device device)
9191
{
92-
Dictionary<string, torch.Tensor> batch = new();
93-
foreach (var x in dic.First().Keys) {
94-
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
95-
if (t.device_type != device.type || t.device_index != device.index)
96-
t = t.to(device);
97-
batch[x] = t;
92+
using (torch.NewDisposeScope()) {
93+
Dictionary<string, torch.Tensor> batch = new();
94+
foreach (var x in dic.First().Keys) {
95+
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
96+
if (t.device_type != device.type || t.device_index != device.index)
97+
t = t.to(device);
98+
batch[x] = t.MoveToOuterDisposeScope();
99+
}
100+
return batch;
98101
}
99-
return batch;
100102
}
101103
}
102104

@@ -143,14 +145,16 @@ public IterableDataLoader(IterableDataset dataset, int batchSize, bool shuffle =
143145

144146
private static IList<torch.Tensor> Collate(IEnumerable<IList<torch.Tensor>> dic, torch.Device device)
145147
{
146-
List<torch.Tensor> batch = new();
147-
for (var x = 0; x < dic.First().Count; x++) {
148-
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
149-
if (t.device_type != device.type || t.device_index != device.index)
150-
t = t.to(device);
151-
batch.Add(t);
148+
using (torch.NewDisposeScope()) {
149+
List<torch.Tensor> batch = new();
150+
for (var x = 0; x < dic.First().Count; x++) {
151+
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
152+
if (t.device_type != device.type || t.device_index != device.index)
153+
t = t.to(device);
154+
batch.Add(t.MoveToOuterDisposeScope());
155+
}
156+
return batch;
152157
}
153-
return batch;
154158
}
155159
}
156160

@@ -167,6 +171,7 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
167171
private IEnumerable<long> shuffler;
168172
private int num_worker;
169173
private Func<IEnumerable<T>, torch.Device, S> collate_fn;
174+
private bool autoDispose;
170175

171176
/// <summary>
172177
/// Pytorch style dataloader
@@ -181,7 +186,18 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
181186
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
182187
/// If false and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.
183188
/// </param>
184-
public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.Device, S> collate_fn, IEnumerable<long> shuffler, Device device = null, int num_worker = 1, bool drop_last = false)
189+
/// <param name="autoDispose">
190+
/// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
191+
/// </param>
192+
public DataLoader(
193+
Dataset<T> dataset,
194+
int batchSize,
195+
Func<IEnumerable<T>, torch.Device, S> collate_fn,
196+
IEnumerable<long> shuffler,
197+
Device device = null,
198+
int num_worker = 1,
199+
bool drop_last = false,
200+
bool autoDispose = true)
185201
{
186202
this.dataset = dataset;
187203
this.batchSize = batchSize;
@@ -191,6 +207,7 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
191207
this.shuffler = shuffler;
192208
this.num_worker = num_worker;
193209
this.collate_fn = collate_fn;
210+
this.autoDispose = autoDispose;
194211
}
195212

196213
/// <summary>
@@ -207,7 +224,19 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
207224
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
208225
/// If false and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.
209226
/// </param>
210-
public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.Device, S> collate_fn, bool shuffle = false, Device device = null, int? seed = null, int num_worker = 1, bool drop_last = false)
227+
/// <param name="autoDispose">
228+
/// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
229+
/// </param>
230+
public DataLoader(
231+
Dataset<T> dataset,
232+
int batchSize,
233+
Func<IEnumerable<T>, torch.Device, S> collate_fn,
234+
bool shuffle = false,
235+
Device device = null,
236+
int? seed = null,
237+
int num_worker = 1,
238+
bool drop_last = false,
239+
bool autoDispose = true)
211240
{
212241
this.dataset = dataset;
213242
this.batchSize = batchSize;
@@ -217,14 +246,17 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
217246
this.shuffler = seed is null ? new FisherYatesShuffler(dataset.Count) : new FisherYatesShuffler(dataset.Count, seed);
218247
this.num_worker = num_worker;
219248
this.collate_fn = collate_fn;
249+
this.autoDispose = autoDispose;
220250
}
221251

222252
/// <summary>
223253
/// Generate enumerator
224254
/// </summary>
225255
/// <returns>Enumerator for batch</returns>
226256
public IEnumerator<S> GetEnumerator() =>
227-
new DataLoaderEnumerator(dataset, batchSize, shuffle, device, shuffler, num_worker, collate_fn);
257+
new DataLoaderEnumerator(
258+
dataset, batchSize, shuffle, device,
259+
shuffler, num_worker, collate_fn, autoDispose);
228260

229261
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
230262

@@ -243,9 +275,17 @@ private class DataLoaderEnumerator : IEnumerator<S>
243275
private IEnumerator<long> shuffler;
244276
private long currentVal = 0;
245277
private int num_worker = 0;
246-
private IList<IDisposable> currentDisposables;
278+
private List<IDisposable> currentDisposables;
247279
private Func<IEnumerable<T>, torch.Device, S> collate_fn;
248-
public DataLoaderEnumerator(Dataset<T> dataset, int batchSize, bool shuffle, Device device, IEnumerable<long> shuffleEnumerable, int num_worker, Func<IEnumerable<T>, torch.Device, S> collate_fn)
280+
public DataLoaderEnumerator(
281+
Dataset<T> dataset,
282+
int batchSize,
283+
bool shuffle,
284+
Device device,
285+
IEnumerable<long> shuffleEnumerable,
286+
int num_worker,
287+
Func<IEnumerable<T>, torch.Device, S> collate_fn,
288+
bool autoDispose)
249289
{
250290
this.dataset = dataset;
251291
this.batchSize = batchSize;
@@ -255,6 +295,7 @@ public DataLoaderEnumerator(Dataset<T> dataset, int batchSize, bool shuffle, Dev
255295
if (num_worker < 1) num_worker = 1;
256296
this.num_worker = num_worker;
257297
this.collate_fn = collate_fn;
298+
this.currentDisposables = autoDispose ? new List<IDisposable>() : null;
258299
Reset();
259300
}
260301

@@ -300,10 +341,19 @@ public bool MoveNext()
300341
foreach (var task in tasks)
301342
task.Wait();
302343

303-
using (var collate_scope = DisposeScopeManager.NewDisposeScope()) {
344+
if (this.currentDisposables is not null) {
345+
using (var collate_scope = DisposeScopeManager.NewDisposeScope()) {
346+
Current = collate_fn(items, device);
347+
currentDisposables.AddRange(collate_scope.DisposablesView);
348+
collate_scope.Detach(currentDisposables);
349+
}
350+
}
351+
else {
304352
Current = collate_fn(items, device);
305-
currentDisposables = collate_scope.DisposablesView.ToList();
306-
collate_scope.Detach(currentDisposables);
353+
}
354+
355+
foreach (var item in items) {
356+
dataset.DisposeTensor(item);
307357
}
308358

309359
return true;
@@ -354,7 +404,7 @@ private void DisposeCurrent()
354404
if (currentDisposables is null) return;
355405
foreach (var x in currentDisposables)
356406
x.Dispose();
357-
currentDisposables = null;
407+
currentDisposables.Clear();
358408
shuffler?.Dispose();
359409
}
360410
}

src/TorchSharp/Dataset.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,25 @@ public static partial class data
1515
/// </summary>
1616
public abstract class Dataset : Dataset<Dictionary<string, torch.Tensor>>
1717
{
18+
public override void DisposeTensor(Dictionary<string, Tensor> tensor)
19+
{
20+
foreach (var t in tensor.Values) {
21+
t.Dispose();
22+
}
23+
}
1824
}
1925

2026
/// <summary>
2127
/// Iterable-style data sets
2228
/// </summary>
2329
public abstract class IterableDataset : Dataset<IList<Tensor>>
2430
{
31+
public override void DisposeTensor(IList<Tensor> tensor)
32+
{
33+
foreach (var t in tensor) {
34+
t.Dispose();
35+
}
36+
}
2537
}
2638

2739
/// <summary>
@@ -47,6 +59,8 @@ public void Dispose()
4759
/// <returns>Tensors of index. DataLoader will catenate these tensors into batches.</returns>
4860
public abstract T GetTensor(long index);
4961

62+
public abstract void DisposeTensor(T tensor);
63+
5064
protected virtual void Dispose(bool disposing)
5165
{
5266
}

src/TorchSharp/Utils/TensorDataset.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ internal TensorDataset(torch.Tensor[] tensors)
3737
long size1 = tensors[0].shape[0];
3838
if (!tensors.All(t => t.shape[0] == size1)) throw new ArgumentException("All tensors must have the same first dimension size.", nameof(tensors));
3939

40-
_tensors.AddRange(tensors);
40+
_tensors = tensors.Select(x => x.alias().DetachFromDisposeScope()).ToArray();
4141
}
4242

4343
/// <summary>
@@ -62,8 +62,16 @@ public override long Count {
6262
return this[index];
6363
}
6464

65-
private List<torch.Tensor> _tensors = new List<torch.Tensor>();
66-
}
65+
private torch.Tensor[] _tensors;
6766

67+
protected override void Dispose(bool disposing)
68+
{
69+
if (disposing) {
70+
foreach (var tensor in _tensors) {
71+
tensor.Dispose();
72+
}
73+
}
74+
}
75+
}
6876
}
6977
}

0 commit comments

Comments
 (0)