@@ -171,6 +171,7 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
171
171
private IEnumerable < long > shuffler ;
172
172
private int num_worker ;
173
173
private Func < IEnumerable < T > , torch . Device , S > collate_fn ;
174
+ private bool autoDispose ;
174
175
175
176
/// <summary>
176
177
/// Pytorch style dataloader
@@ -185,7 +186,18 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
185
186
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
186
187
/// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
187
188
/// </param>
188
- public DataLoader ( Dataset < T > dataset , int batchSize , Func < IEnumerable < T > , torch . Device , S > collate_fn , IEnumerable < long > shuffler , Device device = null , int num_worker = 1 , bool drop_last = false )
189
+ /// <param name="autoDispose">
190
+ /// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
191
+ /// </param>
192
+ public DataLoader (
193
+ Dataset < T > dataset ,
194
+ int batchSize ,
195
+ Func < IEnumerable < T > , torch . Device , S > collate_fn ,
196
+ IEnumerable < long > shuffler ,
197
+ Device device = null ,
198
+ int num_worker = 1 ,
199
+ bool drop_last = false ,
200
+ bool autoDispose = true )
189
201
{
190
202
this . dataset = dataset ;
191
203
this . batchSize = batchSize ;
@@ -195,6 +207,7 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
195
207
this . shuffler = shuffler ;
196
208
this . num_worker = num_worker ;
197
209
this . collate_fn = collate_fn ;
210
+ this . autoDispose = autoDispose ;
198
211
}
199
212
200
213
/// <summary>
@@ -211,7 +224,19 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
211
224
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
212
225
/// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
213
226
/// </param>
214
- public DataLoader ( Dataset < T > dataset , int batchSize , Func < IEnumerable < T > , torch . Device , S > collate_fn , bool shuffle = false , Device device = null , int ? seed = null , int num_worker = 1 , bool drop_last = false )
227
+ /// <param name="autoDispose">
228
+ /// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
229
+ /// </param>
230
+ public DataLoader (
231
+ Dataset < T > dataset ,
232
+ int batchSize ,
233
+ Func < IEnumerable < T > , torch . Device , S > collate_fn ,
234
+ bool shuffle = false ,
235
+ Device device = null ,
236
+ int ? seed = null ,
237
+ int num_worker = 1 ,
238
+ bool drop_last = false ,
239
+ bool autoDispose = true )
215
240
{
216
241
this . dataset = dataset ;
217
242
this . batchSize = batchSize ;
@@ -221,14 +246,17 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
221
246
this . shuffler = seed is null ? new FisherYatesShuffler ( dataset . Count ) : new FisherYatesShuffler ( dataset . Count , seed ) ;
222
247
this . num_worker = num_worker ;
223
248
this . collate_fn = collate_fn ;
249
+ this . autoDispose = autoDispose ;
224
250
}
225
251
226
252
/// <summary>
227
253
/// Generate enumerator
228
254
/// </summary>
229
255
/// <returns>Enumerator for batch</returns>
230
256
public IEnumerator < S > GetEnumerator ( ) =>
231
- new DataLoaderEnumerator ( dataset , batchSize , shuffle , device , shuffler , num_worker , collate_fn ) ;
257
+ new DataLoaderEnumerator (
258
+ dataset , batchSize , shuffle , device ,
259
+ shuffler , num_worker , collate_fn , autoDispose ) ;
232
260
233
261
IEnumerator IEnumerable . GetEnumerator ( ) => GetEnumerator ( ) ;
234
262
@@ -247,9 +275,17 @@ private class DataLoaderEnumerator : IEnumerator<S>
247
275
private IEnumerator < long > shuffler ;
248
276
private long currentVal = 0 ;
249
277
private int num_worker = 0 ;
250
- private IList < IDisposable > currentDisposables ;
278
+ private List < IDisposable > currentDisposables ;
251
279
private Func < IEnumerable < T > , torch . Device , S > collate_fn ;
252
- public DataLoaderEnumerator ( Dataset < T > dataset , int batchSize , bool shuffle , Device device , IEnumerable < long > shuffleEnumerable , int num_worker , Func < IEnumerable < T > , torch . Device , S > collate_fn )
280
+ public DataLoaderEnumerator (
281
+ Dataset < T > dataset ,
282
+ int batchSize ,
283
+ bool shuffle ,
284
+ Device device ,
285
+ IEnumerable < long > shuffleEnumerable ,
286
+ int num_worker ,
287
+ Func < IEnumerable < T > , torch . Device , S > collate_fn ,
288
+ bool autoDispose )
253
289
{
254
290
this . dataset = dataset ;
255
291
this . batchSize = batchSize ;
@@ -259,6 +295,7 @@ public DataLoaderEnumerator(Dataset<T> dataset, int batchSize, bool shuffle, Dev
259
295
if ( num_worker < 1 ) num_worker = 1 ;
260
296
this . num_worker = num_worker ;
261
297
this . collate_fn = collate_fn ;
298
+ this . currentDisposables = autoDispose ? new List < IDisposable > ( ) : null ;
262
299
Reset ( ) ;
263
300
}
264
301
@@ -304,10 +341,15 @@ public bool MoveNext()
304
341
foreach ( var task in tasks )
305
342
task . Wait ( ) ;
306
343
307
- using ( var collate_scope = DisposeScopeManager . NewDisposeScope ( ) ) {
344
+ if ( this . currentDisposables is not null ) {
345
+ using ( var collate_scope = DisposeScopeManager . NewDisposeScope ( ) ) {
346
+ Current = collate_fn ( items , device ) ;
347
+ currentDisposables = collate_scope . DisposablesView . ToList ( ) ;
348
+ collate_scope . Detach ( currentDisposables ) ;
349
+ }
350
+ }
351
+ else {
308
352
Current = collate_fn ( items , device ) ;
309
- currentDisposables = collate_scope . DisposablesView . ToList ( ) ;
310
- collate_scope . Detach ( currentDisposables ) ;
311
353
}
312
354
313
355
return true ;
@@ -358,7 +400,7 @@ private void DisposeCurrent()
358
400
if ( currentDisposables is null ) return ;
359
401
foreach ( var x in currentDisposables )
360
402
x . Dispose ( ) ;
361
- currentDisposables = null ;
403
+ currentDisposables . Clear ( ) ;
362
404
shuffler ? . Dispose ( ) ;
363
405
}
364
406
}
0 commit comments