Skip to content

Commit 9bd4c38

Browse files
committed
auto dispose
1 parent e5385b7 commit 9bd4c38

File tree

1 file changed

+51
-9
lines changed

1 file changed

+51
-9
lines changed

src/TorchSharp/DataLoader.cs

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
171171
private IEnumerable<long> shuffler;
172172
private int num_worker;
173173
private Func<IEnumerable<T>, torch.Device, S> collate_fn;
174+
private bool autoDispose;
174175

175176
/// <summary>
176177
/// Pytorch style dataloader
@@ -185,7 +186,18 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
185186
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
186187
/// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
187188
/// </param>
188-
public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.Device, S> collate_fn, IEnumerable<long> shuffler, Device device = null, int num_worker = 1, bool drop_last = false)
189+
/// <param name="autoDispose">
190+
/// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
191+
/// </param>
192+
public DataLoader(
193+
Dataset<T> dataset,
194+
int batchSize,
195+
Func<IEnumerable<T>, torch.Device, S> collate_fn,
196+
IEnumerable<long> shuffler,
197+
Device device = null,
198+
int num_worker = 1,
199+
bool drop_last = false,
200+
bool autoDispose = true)
189201
{
190202
this.dataset = dataset;
191203
this.batchSize = batchSize;
@@ -195,6 +207,7 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
195207
this.shuffler = shuffler;
196208
this.num_worker = num_worker;
197209
this.collate_fn = collate_fn;
210+
this.autoDispose = autoDispose;
198211
}
199212

200213
/// <summary>
@@ -211,7 +224,19 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
211224
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
212225
/// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
213226
/// </param>
214-
public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.Device, S> collate_fn, bool shuffle = false, Device device = null, int? seed = null, int num_worker = 1, bool drop_last = false)
227+
/// <param name="autoDispose">
228+
/// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
229+
/// </param>
230+
public DataLoader(
231+
Dataset<T> dataset,
232+
int batchSize,
233+
Func<IEnumerable<T>, torch.Device, S> collate_fn,
234+
bool shuffle = false,
235+
Device device = null,
236+
int? seed = null,
237+
int num_worker = 1,
238+
bool drop_last = false,
239+
bool autoDispose = true)
215240
{
216241
this.dataset = dataset;
217242
this.batchSize = batchSize;
@@ -221,14 +246,17 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
221246
this.shuffler = seed is null ? new FisherYatesShuffler(dataset.Count) : new FisherYatesShuffler(dataset.Count, seed);
222247
this.num_worker = num_worker;
223248
this.collate_fn = collate_fn;
249+
this.autoDispose = autoDispose;
224250
}
225251

226252
/// <summary>
227253
/// Generate enumerator
228254
/// </summary>
229255
/// <returns>Enumerator for batch</returns>
230256
public IEnumerator<S> GetEnumerator() =>
231-
new DataLoaderEnumerator(dataset, batchSize, shuffle, device, shuffler, num_worker, collate_fn);
257+
new DataLoaderEnumerator(
258+
dataset, batchSize, shuffle, device,
259+
shuffler, num_worker, collate_fn, autoDispose);
232260

233261
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
234262

@@ -247,9 +275,17 @@ private class DataLoaderEnumerator : IEnumerator<S>
247275
private IEnumerator<long> shuffler;
248276
private long currentVal = 0;
249277
private int num_worker = 0;
250-
private IList<IDisposable> currentDisposables;
278+
private List<IDisposable> currentDisposables;
251279
private Func<IEnumerable<T>, torch.Device, S> collate_fn;
252-
public DataLoaderEnumerator(Dataset<T> dataset, int batchSize, bool shuffle, Device device, IEnumerable<long> shuffleEnumerable, int num_worker, Func<IEnumerable<T>, torch.Device, S> collate_fn)
280+
public DataLoaderEnumerator(
281+
Dataset<T> dataset,
282+
int batchSize,
283+
bool shuffle,
284+
Device device,
285+
IEnumerable<long> shuffleEnumerable,
286+
int num_worker,
287+
Func<IEnumerable<T>, torch.Device, S> collate_fn,
288+
bool autoDispose)
253289
{
254290
this.dataset = dataset;
255291
this.batchSize = batchSize;
@@ -259,6 +295,7 @@ public DataLoaderEnumerator(Dataset<T> dataset, int batchSize, bool shuffle, Dev
259295
if (num_worker < 1) num_worker = 1;
260296
this.num_worker = num_worker;
261297
this.collate_fn = collate_fn;
298+
this.currentDisposables = autoDispose ? new List<IDisposable>() : null;
262299
Reset();
263300
}
264301

@@ -304,10 +341,15 @@ public bool MoveNext()
304341
foreach (var task in tasks)
305342
task.Wait();
306343

307-
using (var collate_scope = DisposeScopeManager.NewDisposeScope()) {
344+
if (this.currentDisposables is not null) {
345+
using (var collate_scope = DisposeScopeManager.NewDisposeScope()) {
346+
Current = collate_fn(items, device);
347+
currentDisposables = collate_scope.DisposablesView.ToList();
348+
collate_scope.Detach(currentDisposables);
349+
}
350+
}
351+
else {
308352
Current = collate_fn(items, device);
309-
currentDisposables = collate_scope.DisposablesView.ToList();
310-
collate_scope.Detach(currentDisposables);
311353
}
312354

313355
return true;
@@ -358,7 +400,7 @@ private void DisposeCurrent()
358400
if (currentDisposables is null) return;
359401
foreach (var x in currentDisposables)
360402
x.Dispose();
361-
currentDisposables = null;
403+
currentDisposables.Clear();
362404
shuffler?.Dispose();
363405
}
364406
}

0 commit comments

Comments
 (0)