Skip to content

Commit e220b28

Browse files
Merge pull request #1306 from yueyinqiu/collate
tensor disposal in data loader
2 parents 8b14055 + 2a9344a commit e220b28

File tree

3 files changed

+150
-128
lines changed

3 files changed

+150
-128
lines changed

RELEASENOTES.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ __Bug Fixes__:
1818

1919
- #1300 `Adadelta`, `Adam` and `AdamW` will no longer throw `NullReferenceException` when `maximize` is `true` and `grad` is `null`.
2020
- `torch.normal` will now correctly return a leaf tensor.
21+
- New options `disposeBatch` and `disposeDataset` have been added into `DataLoader`.
22+
- The default collate functions will now always dispose the intermediate tensors, rather than waiting for the next iteration.
23+
24+
__Bug Fixes__:
25+
26+
- `TensorDataset` will now keep the aliases detached from dispose scopes, to avoid their unexpected disposal.
27+
- `DataLoaderEnumerator` has been completely rewritten to resolve the unexpected shuffler disposal, the ignored `drop_last` option, and the incorrect worker count.
2128

2229
# NuGet Version 0.102.4
2330

src/TorchSharp/DataLoader.cs

Lines changed: 132 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,16 @@ public DataLoader(Dataset dataset, int batchSize, bool shuffle = false, Device d
8989

9090
private static Dictionary<string, torch.Tensor> Collate(IEnumerable<Dictionary<string, torch.Tensor>> dic, torch.Device device)
9191
{
92-
Dictionary<string, torch.Tensor> batch = new();
93-
foreach (var x in dic.First().Keys) {
94-
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
95-
if (t.device_type != device.type || t.device_index != device.index)
96-
t = t.to(device);
97-
batch[x] = t;
92+
using (torch.NewDisposeScope()) {
93+
Dictionary<string, torch.Tensor> batch = new();
94+
foreach (var x in dic.First().Keys) {
95+
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
96+
if (t.device_type != device.type || t.device_index != device.index)
97+
t = t.to(device);
98+
batch[x] = t.MoveToOuterDisposeScope();
99+
}
100+
return batch;
98101
}
99-
return batch;
100102
}
101103
}
102104

@@ -143,30 +145,34 @@ public IterableDataLoader(IterableDataset dataset, int batchSize, bool shuffle =
143145

144146
private static IList<torch.Tensor> Collate(IEnumerable<IList<torch.Tensor>> dic, torch.Device device)
145147
{
146-
List<torch.Tensor> batch = new();
147-
for (var x = 0; x < dic.First().Count; x++) {
148-
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
149-
if (t.device_type != device.type || t.device_index != device.index)
150-
t = t.to(device);
151-
batch.Add(t);
148+
using (torch.NewDisposeScope()) {
149+
List<torch.Tensor> batch = new();
150+
for (var x = 0; x < dic.First().Count; x++) {
151+
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
152+
if (t.device_type != device.type || t.device_index != device.index)
153+
t = t.to(device);
154+
batch.Add(t.MoveToOuterDisposeScope());
155+
}
156+
return batch;
152157
}
153-
return batch;
154158
}
155159
}
156160

161+
#nullable enable
157162
/// <summary>
158163
/// This class supports creating batches from data sets.
159164
/// </summary>
160165
public class DataLoader<T, S> : IEnumerable<S>, IDisposable
161166
{
162-
private Dataset<T> dataset;
163-
private int batchSize;
164-
private bool shuffle;
165-
private bool drop_last;
166-
private Device device;
167-
private IEnumerable<long> shuffler;
168-
private int num_worker;
169-
private Func<IEnumerable<T>, torch.Device, S> collate_fn;
167+
private readonly Dataset<T> dataset;
168+
private readonly int batchSize;
169+
private readonly bool drop_last;
170+
private readonly Device device;
171+
private readonly IEnumerable<long> shuffler;
172+
private readonly int num_worker;
173+
private readonly Func<IEnumerable<T>, torch.Device, S> collate_fn;
174+
private readonly bool disposeBatch;
175+
private readonly bool disposeDataset;
170176

171177
/// <summary>
172178
/// Pytorch style dataloader
@@ -181,16 +187,32 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
181187
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
182188
/// If false and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.
183189
/// </param>
184-
public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.Device, S> collate_fn, IEnumerable<long> shuffler, Device device = null, int num_worker = 1, bool drop_last = false)
190+
/// <param name="disposeBatch">
191+
/// Indicates whether to automatically dispose the collated tensors after an iteration.
192+
/// </param>
193+
/// <param name="disposeDataset">
194+
/// Indicates whether to dispose the dataset when being disposed.
195+
/// </param>
196+
public DataLoader(
197+
Dataset<T> dataset,
198+
int batchSize,
199+
Func<IEnumerable<T>, torch.Device, S> collate_fn,
200+
IEnumerable<long> shuffler,
201+
Device? device = null,
202+
int num_worker = 1,
203+
bool drop_last = false,
204+
bool disposeBatch = true,
205+
bool disposeDataset = true)
185206
{
186207
this.dataset = dataset;
187208
this.batchSize = batchSize;
188-
this.shuffle = true;
189209
this.drop_last = drop_last;
190210
this.device = device ?? CPU;
191211
this.shuffler = shuffler;
192-
this.num_worker = num_worker;
212+
this.num_worker = Math.Max(num_worker, 1);
193213
this.collate_fn = collate_fn;
214+
this.disposeBatch = disposeBatch;
215+
this.disposeDataset = disposeDataset;
194216
}
195217

196218
/// <summary>
@@ -207,24 +229,39 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
207229
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
208230
/// If false and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.
209231
/// </param>
210-
public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.Device, S> collate_fn, bool shuffle = false, Device device = null, int? seed = null, int num_worker = 1, bool drop_last = false)
232+
/// <param name="disposeBatch">
233+
/// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
234+
/// </param>
235+
/// <param name="disposeDataset">
236+
/// Indicates whether to dispose the dataset when being disposed.
237+
/// </param>
238+
public DataLoader(
239+
Dataset<T> dataset,
240+
int batchSize,
241+
Func<IEnumerable<T>, torch.Device, S> collate_fn,
242+
bool shuffle = false,
243+
Device? device = null,
244+
int? seed = null,
245+
int num_worker = 1,
246+
bool drop_last = false,
247+
bool disposeBatch = true,
248+
bool disposeDataset = true) :
249+
this(dataset, batchSize, collate_fn,
250+
shuffle ? new FisherYatesShuffler(dataset.Count, seed) : LongRange(dataset.Count),
251+
device, num_worker, drop_last, disposeBatch, disposeDataset)
252+
{ }
253+
254+
static IEnumerable<long> LongRange(long count)
211255
{
212-
this.dataset = dataset;
213-
this.batchSize = batchSize;
214-
this.shuffle = shuffle;
215-
this.drop_last = drop_last;
216-
this.device = device ?? CPU;
217-
this.shuffler = seed is null ? new FisherYatesShuffler(dataset.Count) : new FisherYatesShuffler(dataset.Count, seed);
218-
this.num_worker = num_worker;
219-
this.collate_fn = collate_fn;
256+
for (long i = 0; i < count; i++)
257+
yield return i;
220258
}
221259

222260
/// <summary>
223261
/// Generate enumerator
224262
/// </summary>
225263
/// <returns>Enumerator for batch</returns>
226-
public IEnumerator<S> GetEnumerator() =>
227-
new DataLoaderEnumerator(dataset, batchSize, shuffle, device, shuffler, num_worker, collate_fn);
264+
public IEnumerator<S> GetEnumerator() => new DataLoaderEnumerator(this);
228265

229266
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
230267

@@ -233,41 +270,39 @@ public IEnumerator<S> GetEnumerator() =>
233270
/// </summary>
234271
public long Count => drop_last ? (dataset.Count / batchSize) : ((dataset.Count - 1) / batchSize + 1);
235272

236-
private class DataLoaderEnumerator : IEnumerator<S>
273+
public void Dispose()
274+
{
275+
Dispose(true);
276+
GC.SuppressFinalize(this);
277+
}
278+
279+
protected virtual void Dispose(bool disposing)
280+
{
281+
if (disposing && disposeDataset) {
282+
dataset.Dispose();
283+
}
284+
}
285+
286+
sealed class DataLoaderEnumerator : IEnumerator<S>
237287
{
238-
private Dataset<T> dataset;
239-
private int batchSize;
240-
private Device device;
241-
private bool shuffle;
242-
private IEnumerable<long> shuffleEnumerable;
288+
private readonly DataLoader<T, S> loader;
243289
private IEnumerator<long> shuffler;
244-
private long currentVal = 0;
245-
private int num_worker = 0;
246-
private IList<IDisposable> currentDisposables;
247-
private Func<IEnumerable<T>, torch.Device, S> collate_fn;
248-
public DataLoaderEnumerator(Dataset<T> dataset, int batchSize, bool shuffle, Device device, IEnumerable<long> shuffleEnumerable, int num_worker, Func<IEnumerable<T>, torch.Device, S> collate_fn)
290+
private IReadOnlyList<IDisposable> currentDisposables;
291+
public DataLoaderEnumerator(DataLoader<T, S> loader)
249292
{
250-
this.dataset = dataset;
251-
this.batchSize = batchSize;
252-
this.device = device;
253-
this.shuffle = shuffle;
254-
this.shuffleEnumerable = shuffleEnumerable;
255-
if (num_worker < 1) num_worker = 1;
256-
this.num_worker = num_worker;
257-
this.collate_fn = collate_fn;
293+
this.loader = loader;
294+
this.currentDisposables = Array.Empty<IDisposable>();
295+
// TODO: Use MemberNotNull instead.
296+
shuffler = null!;
258297
Reset();
259298
}
260299

261-
private bool MoveNextValue()
300+
private long? MoveNextValue()
262301
{
263-
if (shuffle) {
264-
if (!shuffler.MoveNext()) return false;
265-
currentVal = shuffler.Current;
266-
return true;
267-
} else {
268-
currentVal++;
269-
return currentVal < dataset.Count;
302+
if (!shuffler.MoveNext()) {
303+
return null;
270304
}
305+
return shuffler.Current;
271306
}
272307

273308
/// <summary>
@@ -277,53 +312,38 @@ private bool MoveNextValue()
277312
public bool MoveNext()
278313
{
279314
DisposeCurrent();
280-
using (var scope = DisposeScopeManager.NewDisposeScope()) {
281-
if (!MoveNextValue()) return false;
282315

283-
var tensorIndexList = new List<long> { currentVal };
284-
for (int i = 1; i < batchSize; i++) {
285-
if (!MoveNextValue()) break;
286-
tensorIndexList.Add(currentVal);
316+
using (var scope = torch.NewDisposeScope()) {
317+
var indices = Enumerable.Range(0, loader.batchSize)
318+
.Select(_ => MoveNextValue())
319+
.Where(x => x.HasValue)
320+
.Cast<long>()
321+
.ToArray();
322+
if (indices.Length is 0)
323+
return false;
324+
if (loader.drop_last && indices.Length < loader.batchSize) {
325+
return false;
287326
}
288327

289-
var items = new List<T>(new T[tensorIndexList.Count]);
290-
var taskedBatchCount = 0;
291-
292-
//Run Async
293-
var tasks = new List<Task>();
294-
foreach (var _ in Enumerable.Range(1, num_worker - 1))
295-
tasks.Add(new(ProcessPendingBatches));
296-
tasks.ForEach(x => x.Start());
297-
298-
ProcessPendingBatches();
299-
300-
foreach (var task in tasks)
301-
task.Wait();
302-
303-
using (var collate_scope = DisposeScopeManager.NewDisposeScope()) {
304-
Current = collate_fn(items, device);
305-
currentDisposables = collate_scope.DisposablesView.ToList();
306-
collate_scope.Detach(currentDisposables);
328+
var tensors = new T[indices.Length];
329+
Enumerable.Range(0, indices.Length)
330+
.AsParallel()
331+
.WithDegreeOfParallelism(loader.num_worker)
332+
.ForAll((i) => {
333+
tensors[i] = loader.dataset.GetTensor(indices[i]);
334+
});
335+
336+
using var collate_scope = DisposeScopeManager.NewDisposeScope();
337+
current = loader.collate_fn(tensors, loader.device);
338+
339+
// TODO: Will be better if we have something like DetachAll
340+
var view = collate_scope.DisposablesView;
341+
collate_scope.Detach(view);
342+
if (loader.disposeBatch) {
343+
this.currentDisposables = view;
307344
}
308345

309346
return true;
310-
311-
void ProcessPendingBatches()
312-
{
313-
while (true) {
314-
var idx = ScheduleBatch();
315-
if (idx is null) break;
316-
items[idx.Value.Item1] = dataset.GetTensor(idx.Value.Item2);
317-
}
318-
}
319-
320-
(int, long)? ScheduleBatch()
321-
{
322-
var t = Interlocked.Increment(ref taskedBatchCount) - 1;
323-
if (t < tensorIndexList.Count)
324-
return (t, tensorIndexList[t]);
325-
return null;
326-
}
327347
}
328348
}
329349

@@ -333,42 +353,29 @@ void ProcessPendingBatches()
333353
public void Reset()
334354
{
335355
DisposeCurrent();
336-
if (shuffle) shuffler = shuffleEnumerable.GetEnumerator();
337-
currentVal = -1;
356+
shuffler?.Dispose();
357+
shuffler = loader.shuffler.GetEnumerator();
338358
}
339359

360+
S? current;
340361
/// <summary>
341362
/// Current tensor
342363
/// </summary>
343-
public S Current { get; private set; }
364+
public S Current => current!;
344365

345-
object IEnumerator.Current => Current;
366+
object IEnumerator.Current => current!;
346367

347368
public void Dispose()
348369
{
370+
shuffler.Dispose();
349371
DisposeCurrent();
350372
}
351373

352374
private void DisposeCurrent()
353375
{
354-
if (currentDisposables is null) return;
355-
foreach (var x in currentDisposables)
376+
foreach (var x in this.currentDisposables)
356377
x.Dispose();
357-
currentDisposables = null;
358-
shuffler?.Dispose();
359-
}
360-
}
361-
362-
public void Dispose()
363-
{
364-
Dispose(true);
365-
GC.SuppressFinalize(this);
366-
}
367-
368-
protected virtual void Dispose(bool disposing)
369-
{
370-
if (disposing) {
371-
dataset.Dispose();
378+
this.currentDisposables = Array.Empty<IDisposable>();
372379
}
373380
}
374381
}

0 commit comments

Comments
 (0)