@@ -89,14 +89,16 @@ public DataLoader(Dataset dataset, int batchSize, bool shuffle = false, Device d
89
89
90
90
private static Dictionary < string , torch . Tensor > Collate ( IEnumerable < Dictionary < string , torch . Tensor > > dic , torch . Device device )
91
91
{
92
- Dictionary < string , torch . Tensor > batch = new ( ) ;
93
- foreach ( var x in dic . First ( ) . Keys ) {
94
- var t = cat ( dic . Select ( k => k [ x ] . unsqueeze ( 0 ) ) . ToArray ( ) , 0 ) ;
95
- if ( t . device_type != device . type || t . device_index != device . index )
96
- t = t . to ( device ) ;
97
- batch [ x ] = t ;
92
+ using ( torch . NewDisposeScope ( ) ) {
93
+ Dictionary < string , torch . Tensor > batch = new ( ) ;
94
+ foreach ( var x in dic . First ( ) . Keys ) {
95
+ var t = cat ( dic . Select ( k => k [ x ] . unsqueeze ( 0 ) ) . ToArray ( ) , 0 ) ;
96
+ if ( t . device_type != device . type || t . device_index != device . index )
97
+ t = t . to ( device ) ;
98
+ batch [ x ] = t . MoveToOuterDisposeScope ( ) ;
99
+ }
100
+ return batch ;
98
101
}
99
- return batch ;
100
102
}
101
103
}
102
104
@@ -143,14 +145,16 @@ public IterableDataLoader(IterableDataset dataset, int batchSize, bool shuffle =
143
145
144
146
private static IList < torch . Tensor > Collate ( IEnumerable < IList < torch . Tensor > > dic , torch . Device device )
145
147
{
146
- List < torch . Tensor > batch = new ( ) ;
147
- for ( var x = 0 ; x < dic . First ( ) . Count ; x ++ ) {
148
- var t = cat ( dic . Select ( k => k [ x ] . unsqueeze ( 0 ) ) . ToArray ( ) , 0 ) ;
149
- if ( t . device_type != device . type || t . device_index != device . index )
150
- t = t . to ( device ) ;
151
- batch . Add ( t ) ;
148
+ using ( torch . NewDisposeScope ( ) ) {
149
+ List < torch . Tensor > batch = new ( ) ;
150
+ for ( var x = 0 ; x < dic . First ( ) . Count ; x ++ ) {
151
+ var t = cat ( dic . Select ( k => k [ x ] . unsqueeze ( 0 ) ) . ToArray ( ) , 0 ) ;
152
+ if ( t . device_type != device . type || t . device_index != device . index )
153
+ t = t . to ( device ) ;
154
+ batch . Add ( t . MoveToOuterDisposeScope ( ) ) ;
155
+ }
156
+ return batch ;
152
157
}
153
- return batch ;
154
158
}
155
159
}
156
160
@@ -167,6 +171,7 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
167
171
private IEnumerable < long > shuffler ;
168
172
private int num_worker ;
169
173
private Func < IEnumerable < T > , torch . Device , S > collate_fn ;
174
+ private bool autoDispose ;
170
175
171
176
/// <summary>
172
177
/// Pytorch style dataloader
@@ -181,7 +186,18 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
181
186
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
182
187
/// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
183
188
/// </param>
184
- public DataLoader ( Dataset < T > dataset , int batchSize , Func < IEnumerable < T > , torch . Device , S > collate_fn , IEnumerable < long > shuffler , Device device = null , int num_worker = 1 , bool drop_last = false )
189
+ /// <param name="autoDispose">
190
+ /// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
191
+ /// </param>
192
+ public DataLoader (
193
+ Dataset < T > dataset ,
194
+ int batchSize ,
195
+ Func < IEnumerable < T > , torch . Device , S > collate_fn ,
196
+ IEnumerable < long > shuffler ,
197
+ Device device = null ,
198
+ int num_worker = 1 ,
199
+ bool drop_last = false ,
200
+ bool autoDispose = true )
185
201
{
186
202
this . dataset = dataset ;
187
203
this . batchSize = batchSize ;
@@ -191,6 +207,7 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
191
207
this . shuffler = shuffler ;
192
208
this . num_worker = num_worker ;
193
209
this . collate_fn = collate_fn ;
210
+ this . autoDispose = autoDispose ;
194
211
}
195
212
196
213
/// <summary>
@@ -207,7 +224,19 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
207
224
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
208
225
/// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
209
226
/// </param>
210
- public DataLoader ( Dataset < T > dataset , int batchSize , Func < IEnumerable < T > , torch . Device , S > collate_fn , bool shuffle = false , Device device = null , int ? seed = null , int num_worker = 1 , bool drop_last = false )
227
+ /// <param name="autoDispose">
228
+ /// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
229
+ /// </param>
230
+ public DataLoader (
231
+ Dataset < T > dataset ,
232
+ int batchSize ,
233
+ Func < IEnumerable < T > , torch . Device , S > collate_fn ,
234
+ bool shuffle = false ,
235
+ Device device = null ,
236
+ int ? seed = null ,
237
+ int num_worker = 1 ,
238
+ bool drop_last = false ,
239
+ bool autoDispose = true )
211
240
{
212
241
this . dataset = dataset ;
213
242
this . batchSize = batchSize ;
@@ -217,14 +246,17 @@ public DataLoader(Dataset<T> dataset, int batchSize, Func<IEnumerable<T>, torch.
217
246
this . shuffler = seed is null ? new FisherYatesShuffler ( dataset . Count ) : new FisherYatesShuffler ( dataset . Count , seed ) ;
218
247
this . num_worker = num_worker ;
219
248
this . collate_fn = collate_fn ;
249
+ this . autoDispose = autoDispose ;
220
250
}
221
251
222
252
/// <summary>
223
253
/// Generate enumerator
224
254
/// </summary>
225
255
/// <returns>Enumerator for batch</returns>
226
256
public IEnumerator < S > GetEnumerator ( ) =>
227
- new DataLoaderEnumerator ( dataset , batchSize , shuffle , device , shuffler , num_worker , collate_fn ) ;
257
+ new DataLoaderEnumerator (
258
+ dataset , batchSize , shuffle , device ,
259
+ shuffler , num_worker , collate_fn , autoDispose ) ;
228
260
229
261
IEnumerator IEnumerable . GetEnumerator ( ) => GetEnumerator ( ) ;
230
262
@@ -243,9 +275,17 @@ private class DataLoaderEnumerator : IEnumerator<S>
243
275
private IEnumerator < long > shuffler ;
244
276
private long currentVal = 0 ;
245
277
private int num_worker = 0 ;
246
- private IList < IDisposable > currentDisposables ;
278
+ private List < IDisposable > currentDisposables ;
247
279
private Func < IEnumerable < T > , torch . Device , S > collate_fn ;
248
- public DataLoaderEnumerator ( Dataset < T > dataset , int batchSize , bool shuffle , Device device , IEnumerable < long > shuffleEnumerable , int num_worker , Func < IEnumerable < T > , torch . Device , S > collate_fn )
280
+ public DataLoaderEnumerator (
281
+ Dataset < T > dataset ,
282
+ int batchSize ,
283
+ bool shuffle ,
284
+ Device device ,
285
+ IEnumerable < long > shuffleEnumerable ,
286
+ int num_worker ,
287
+ Func < IEnumerable < T > , torch . Device , S > collate_fn ,
288
+ bool autoDispose )
249
289
{
250
290
this . dataset = dataset ;
251
291
this . batchSize = batchSize ;
@@ -255,6 +295,7 @@ public DataLoaderEnumerator(Dataset<T> dataset, int batchSize, bool shuffle, Dev
255
295
if ( num_worker < 1 ) num_worker = 1 ;
256
296
this . num_worker = num_worker ;
257
297
this . collate_fn = collate_fn ;
298
+ this . currentDisposables = autoDispose ? new List < IDisposable > ( ) : null ;
258
299
Reset ( ) ;
259
300
}
260
301
@@ -300,10 +341,19 @@ public bool MoveNext()
300
341
foreach ( var task in tasks )
301
342
task . Wait ( ) ;
302
343
303
- using ( var collate_scope = DisposeScopeManager . NewDisposeScope ( ) ) {
344
+ if ( this . currentDisposables is not null ) {
345
+ using ( var collate_scope = DisposeScopeManager . NewDisposeScope ( ) ) {
346
+ Current = collate_fn ( items , device ) ;
347
+ currentDisposables . AddRange ( collate_scope . DisposablesView ) ;
348
+ collate_scope . Detach ( currentDisposables ) ;
349
+ }
350
+ }
351
+ else {
304
352
Current = collate_fn ( items , device ) ;
305
- currentDisposables = collate_scope . DisposablesView . ToList ( ) ;
306
- collate_scope . Detach ( currentDisposables ) ;
353
+ }
354
+
355
+ foreach ( var item in items ) {
356
+ dataset . DisposeTensor ( item ) ;
307
357
}
308
358
309
359
return true ;
@@ -354,7 +404,7 @@ private void DisposeCurrent()
354
404
if ( currentDisposables is null ) return ;
355
405
foreach ( var x in currentDisposables )
356
406
x . Dispose ( ) ;
357
- currentDisposables = null ;
407
+ currentDisposables . Clear ( ) ;
358
408
shuffler ? . Dispose ( ) ;
359
409
}
360
410
}
0 commit comments