@@ -89,14 +89,16 @@ public DataLoader(Dataset dataset, int batchSize, bool shuffle = false, Device d

        private static Dictionary<string, torch.Tensor> Collate(IEnumerable<Dictionary<string, torch.Tensor>> dic, torch.Device device)
        {
-            Dictionary<string, torch.Tensor> batch = new();
-            foreach (var x in dic.First().Keys) {
-                var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
-                if (t.device_type != device.type || t.device_index != device.index)
-                    t = t.to(device);
-                batch[x] = t;
+            using (torch.NewDisposeScope()) {
+                Dictionary<string, torch.Tensor> batch = new();
+                foreach (var x in dic.First().Keys) {
+                    var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
+                    if (t.device_type != device.type || t.device_index != device.index)
+                        t = t.to(device);
+                    batch[x] = t.MoveToOuterDisposeScope();
+                }
+                return batch;
            }
-            return batch;
        }
    }

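Both `Collate` overloads (here and in the `IterableDataLoader` hunk below) now follow the same dispose-scope pattern: temporaries created while stacking samples are registered on a fresh scope and freed when it closes, and only the collated result is promoted to the caller's scope via `MoveToOuterDisposeScope()`. A minimal standalone sketch of that pattern, with illustrative names that are not part of the patch:

```csharp
using System.Collections.Generic;
using System.Linq;
using TorchSharp;
using static TorchSharp.torch;

static class CollateSketch
{
    // Stacks N same-shaped samples into one [N, ...] tensor on the target device.
    public static Tensor Stack(IReadOnlyList<Tensor> samples, Device device)
    {
        using (NewDisposeScope()) {
            // The unsqueezed views below are tracked by the current scope
            // and disposed when the scope ends.
            var batch = cat(samples.Select(t => t.unsqueeze(0)).ToArray(), 0);
            if (batch.device_type != device.type || batch.device_index != device.index)
                batch = batch.to(device);
            // Only the final batch tensor survives the scope.
            return batch.MoveToOuterDisposeScope();
        }
    }
}
```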
@@ -143,30 +145,34 @@ public IterableDataLoader(IterableDataset dataset, int batchSize, bool shuffle =

        private static IList<torch.Tensor> Collate(IEnumerable<IList<torch.Tensor>> dic, torch.Device device)
        {
-            List<torch.Tensor> batch = new();
-            for (var x = 0; x < dic.First().Count; x++) {
-                var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
-                if (t.device_type != device.type || t.device_index != device.index)
-                    t = t.to(device);
-                batch.Add(t);
+            using (torch.NewDisposeScope()) {
+                List<torch.Tensor> batch = new();
+                for (var x = 0; x < dic.First().Count; x++) {
+                    var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
+                    if (t.device_type != device.type || t.device_index != device.index)
+                        t = t.to(device);
+                    batch.Add(t.MoveToOuterDisposeScope());
+                }
+                return batch;
            }
-            return batch;
        }
    }

+    #nullable enable
    /// <summary>
    /// This class supports creating batches from data sets.
    /// </summary>
    public class DataLoader<T, S> : IEnumerable<S>, IDisposable
    {
-        private Dataset<T> dataset;
-        private int batchSize;
-        private bool shuffle;
-        private bool drop_last;
-        private Device device;
-        private IEnumerable<long> shuffler;
-        private int num_worker;
-        private Func<IEnumerable<T>, torch.Device, S> collate_fn;
+        private readonly Dataset<T> dataset;
+        private readonly int batchSize;
+        private readonly bool drop_last;
+        private readonly Device device;
+        private readonly IEnumerable<long> shuffler;
+        private readonly int num_worker;
+        private readonly Func<IEnumerable<T>, torch.Device, S> collate_fn;
+        private readonly bool disposeBatch;
+        private readonly bool disposeDataset;

        /// <summary>
        /// Pytorch style dataloader
@@ -181,16 +187,32 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
        /// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
        /// If false and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.
        /// </param>
-        public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.Device, S> collate_fn, IEnumerable<long> shuffler, Device device = null, int num_worker = 1, bool drop_last = false)
+        /// <param name="disposeBatch">
+        /// Indicates whether to automatically dispose the collated tensors after an iteration.
+        /// </param>
+        /// <param name="disposeDataset">
+        /// Indicates whether to dispose the dataset when being disposed.
+        /// </param>
+        public DataLoader(
+            Dataset<T> dataset,
+            int batchSize,
+            Func<IEnumerable<T>, torch.Device, S> collate_fn,
+            IEnumerable<long> shuffler,
+            Device? device = null,
+            int num_worker = 1,
+            bool drop_last = false,
+            bool disposeBatch = true,
+            bool disposeDataset = true)
        {
            this.dataset = dataset;
            this.batchSize = batchSize;
-            this.shuffle = true;
            this.drop_last = drop_last;
            this.device = device ?? CPU;
            this.shuffler = shuffler;
-            this.num_worker = num_worker;
+            this.num_worker = Math.Max(num_worker, 1);
            this.collate_fn = collate_fn;
+            this.disposeBatch = disposeBatch;
+            this.disposeDataset = disposeDataset;
        }

        /// <summary>
@@ -207,24 +229,39 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
        /// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
        /// If false and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.
        /// </param>
-        public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.Device, S> collate_fn, bool shuffle = false, Device device = null, int? seed = null, int num_worker = 1, bool drop_last = false)
+        /// <param name="disposeBatch">
+        /// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
+        /// </param>
+        /// <param name="disposeDataset">
+        /// Indicates whether to dispose the dataset when being disposed.
+        /// </param>
+        public DataLoader(
+            Dataset<T> dataset,
+            int batchSize,
+            Func<IEnumerable<T>, torch.Device, S> collate_fn,
+            bool shuffle = false,
+            Device? device = null,
+            int? seed = null,
+            int num_worker = 1,
+            bool drop_last = false,
+            bool disposeBatch = true,
+            bool disposeDataset = true) :
+            this(dataset, batchSize, collate_fn,
+                shuffle ? new FisherYatesShuffler(dataset.Count, seed) : LongRange(dataset.Count),
+                device, num_worker, drop_last, disposeBatch, disposeDataset)
+        { }
+
+        static IEnumerable<long> LongRange(long count)
        {
-            this.dataset = dataset;
-            this.batchSize = batchSize;
-            this.shuffle = shuffle;
-            this.drop_last = drop_last;
-            this.device = device ?? CPU;
-            this.shuffler = seed is null ? new FisherYatesShuffler(dataset.Count) : new FisherYatesShuffler(dataset.Count, seed);
-            this.num_worker = num_worker;
-            this.collate_fn = collate_fn;
+            for (long i = 0; i < count; i++)
+                yield return i;
        }

        /// <summary>
        /// Generate enumerator
        /// </summary>
        /// <returns>Enumerator for batch</returns>
-        public IEnumerator<S> GetEnumerator() =>
-            new DataLoaderEnumerator(dataset, batchSize, shuffle, device, shuffler, num_worker, collate_fn);
+        public IEnumerator<S> GetEnumerator() => new DataLoaderEnumerator(this);

        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

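With the delegating constructor above, shuffling is expressed purely through the index stream handed to the primary constructor: `FisherYatesShuffler(dataset.Count, seed)` when `shuffle` is true, `LongRange(dataset.Count)` otherwise. Any `IEnumerable<long>` works as that argument; the helper below is a hypothetical example of a custom index source, not something in the patch:

```csharp
using System.Collections.Generic;

static class IndexSources
{
    // Yields every other index: an IEnumerable<long> that could be passed as the
    // `shuffler` argument of the primary constructor to subsample a dataset.
    public static IEnumerable<long> EveryOther(long count)
    {
        for (long i = 0; i < count; i += 2)
            yield return i;
    }
}

// e.g. new DataLoader<T, S>(dataset, 32, collate_fn, IndexSources.EveryOther(dataset.Count))
```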
@@ -233,41 +270,39 @@ public IEnumerator<S> GetEnumerator() =>
        /// </summary>
        public long Count => drop_last ? (dataset.Count / batchSize) : ((dataset.Count - 1) / batchSize + 1);

-        private class DataLoaderEnumerator : IEnumerator<S>
+        public void Dispose()
+        {
+            Dispose(true);
+            GC.SuppressFinalize(this);
+        }
+
+        protected virtual void Dispose(bool disposing)
+        {
+            if (disposing && disposeDataset) {
+                dataset.Dispose();
+            }
+        }
+
+        sealed class DataLoaderEnumerator : IEnumerator<S>
        {
-            private Dataset<T> dataset;
-            private int batchSize;
-            private Device device;
-            private bool shuffle;
-            private IEnumerable<long> shuffleEnumerable;
+            private readonly DataLoader<T, S> loader;
            private IEnumerator<long> shuffler;
-            private long currentVal = 0;
-            private int num_worker = 0;
-            private IList<IDisposable> currentDisposables;
-            private Func<IEnumerable<T>, torch.Device, S> collate_fn;
-            public DataLoaderEnumerator(Dataset<T> dataset, int batchSize, bool shuffle, Device device, IEnumerable<long> shuffleEnumerable, int num_worker, Func<IEnumerable<T>, torch.Device, S> collate_fn)
+            private IReadOnlyList<IDisposable> currentDisposables;
+            public DataLoaderEnumerator(DataLoader<T, S> loader)
            {
-                this.dataset = dataset;
-                this.batchSize = batchSize;
-                this.device = device;
-                this.shuffle = shuffle;
-                this.shuffleEnumerable = shuffleEnumerable;
-                if (num_worker < 1) num_worker = 1;
-                this.num_worker = num_worker;
-                this.collate_fn = collate_fn;
+                this.loader = loader;
+                this.currentDisposables = Array.Empty<IDisposable>();
+                // TODO: Use MemberNotNull instead.
+                shuffler = null!;
                Reset();
            }

-            private bool MoveNextValue()
+            private long? MoveNextValue()
            {
-                if (shuffle) {
-                    if (!shuffler.MoveNext()) return false;
-                    currentVal = shuffler.Current;
-                    return true;
-                } else {
-                    currentVal++;
-                    return currentVal < dataset.Count;
+                if (!shuffler.MoveNext()) {
+                    return null;
                }
+                return shuffler.Current;
            }

            /// <summary>
@@ -277,53 +312,38 @@ private bool MoveNextValue()
            public bool MoveNext()
            {
                DisposeCurrent();
-                using (var scope = DisposeScopeManager.NewDisposeScope()) {
-                    if (!MoveNextValue()) return false;

-                    var tensorIndexList = new List<long> { currentVal };
-                    for (int i = 1; i < batchSize; i++) {
-                        if (!MoveNextValue()) break;
-                        tensorIndexList.Add(currentVal);
+                using (var scope = torch.NewDisposeScope()) {
+                    var indices = Enumerable.Range(0, loader.batchSize)
+                        .Select(_ => MoveNextValue())
+                        .Where(x => x.HasValue)
+                        .Cast<long>()
+                        .ToArray();
+                    if (indices.Length is 0)
+                        return false;
+                    if (loader.drop_last && indices.Length < loader.batchSize) {
+                        return false;
                    }

-                    var items = new List<T>(new T[tensorIndexList.Count]);
-                    var taskedBatchCount = 0;
-
-                    //Run Async
-                    var tasks = new List<Task>();
-                    foreach (var _ in Enumerable.Range(1, num_worker - 1))
-                        tasks.Add(new(ProcessPendingBatches));
-                    tasks.ForEach(x => x.Start());
-
-                    ProcessPendingBatches();
-
-                    foreach (var task in tasks)
-                        task.Wait();
-
-                    using (var collate_scope = DisposeScopeManager.NewDisposeScope()) {
-                        Current = collate_fn(items, device);
-                        currentDisposables = collate_scope.DisposablesView.ToList();
-                        collate_scope.Detach(currentDisposables);
+                    var tensors = new T[indices.Length];
+                    Enumerable.Range(0, indices.Length)
+                        .AsParallel()
+                        .WithDegreeOfParallelism(loader.num_worker)
+                        .ForAll((i) => {
+                            tensors[i] = loader.dataset.GetTensor(indices[i]);
+                        });
+
+                    using var collate_scope = DisposeScopeManager.NewDisposeScope();
+                    current = loader.collate_fn(tensors, loader.device);
+
+                    // TODO: Will be better if we have something like DetachAll
+                    var view = collate_scope.DisposablesView;
+                    collate_scope.Detach(view);
+                    if (loader.disposeBatch) {
+                        this.currentDisposables = view;
                    }

                    return true;
-
-                    void ProcessPendingBatches()
-                    {
-                        while (true) {
-                            var idx = ScheduleBatch();
-                            if (idx is null) break;
-                            items[idx.Value.Item1] = dataset.GetTensor(idx.Value.Item2);
-                        }
-                    }
-
-                    (int, long)? ScheduleBatch()
-                    {
-                        var t = Interlocked.Increment(ref taskedBatchCount) - 1;
-                        if (t < tensorIndexList.Count)
-                            return (t, tensorIndexList[t]);
-                        return null;
-                    }
                }
            }

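The hand-rolled `Task`/`Interlocked` worker loop is replaced by a PLINQ pipeline bounded by `num_worker`. For reference, the same fan-out pattern in isolation, with a placeholder item type and fetch delegate standing in for `Dataset<T>.GetTensor`:

```csharp
using System;
using System.Linq;

static class ParallelFetchSketch
{
    // Fetches one item per index on up to maxWorkers threads; results land in
    // their original positions, so batch order still follows the index order.
    public static TItem[] FetchAll<TItem>(long[] indices, Func<long, TItem> getItem, int maxWorkers)
    {
        var items = new TItem[indices.Length];
        Enumerable.Range(0, indices.Length)
            .AsParallel()
            .WithDegreeOfParallelism(maxWorkers)   // must be >= 1, hence Math.Max in the constructor
            .ForAll(i => items[i] = getItem(indices[i]));
        return items;
    }
}
```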
@@ -333,42 +353,29 @@ void ProcessPendingBatches()
            public void Reset()
            {
                DisposeCurrent();
-                if (shuffle) shuffler = shuffleEnumerable.GetEnumerator();
-                currentVal = -1;
+                shuffler?.Dispose();
+                shuffler = loader.shuffler.GetEnumerator();
            }

+            S? current;
            /// <summary>
            /// Current tensor
            /// </summary>
-            public S Current { get; private set; }
+            public S Current => current!;

-            object IEnumerator.Current => Current;
+            object IEnumerator.Current => current!;

            public void Dispose()
            {
+                shuffler.Dispose();
                DisposeCurrent();
            }

            private void DisposeCurrent()
            {
-                if (currentDisposables is null) return;
-                foreach (var x in currentDisposables)
+                foreach (var x in this.currentDisposables)
                    x.Dispose();
-                currentDisposables = null;
-                shuffler?.Dispose();
-            }
-        }
-
-        public void Dispose()
-        {
-            Dispose(true);
-            GC.SuppressFinalize(this);
-        }
-
-        protected virtual void Dispose(bool disposing)
-        {
-            if (disposing) {
-                dataset.Dispose();
+                this.currentDisposables = Array.Empty<IDisposable>();
            }
        }
    }
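For reference, a usage sketch against the revised generic API. The toy dataset and collate function are assumptions for illustration (they rely on `Dataset<T>` exposing `Count` and `GetTensor` as used in the loader code above, and on the types being reachable via `torch.utils.data`); none of this is part of the patch:

```csharp
using System;
using System.Linq;
using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.utils.data;

// Toy dataset: 100 random samples of shape [3].
class ToyDataset : Dataset<Tensor>
{
    public override long Count => 100;
    public override Tensor GetTensor(long index) => rand(3);
}

static class Program
{
    static void Main()
    {
        using var loader = new DataLoader<Tensor, Tensor>(
            new ToyDataset(),
            batchSize: 16,
            collate_fn: (items, dev) => cat(items.Select(t => t.unsqueeze(0)).ToArray(), 0).to(dev),
            shuffle: true,
            seed: 42,
            num_worker: 4,
            drop_last: true,         // Count == 100 / 16 == 6 full batches
            disposeBatch: true,      // each collated batch is disposed when the enumerator advances
            disposeDataset: true);   // the dataset is disposed together with the loader

        foreach (var batch in loader)
            Console.WriteLine(string.Join("x", batch.shape));   // prints "16x3" six times
    }
}
```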