File tree Expand file tree Collapse file tree 3 files changed +29
-3
lines changed Expand file tree Collapse file tree 3 files changed +29
-3
lines changed Original file line number Diff line number Diff line change @@ -352,6 +352,10 @@ public bool MoveNext()
352
352
Current = collate_fn ( items , device ) ;
353
353
}
354
354
355
+ foreach ( var item in items ) {
356
+ dataset . DisposeTensor ( item ) ;
357
+ }
358
+
355
359
return true ;
356
360
357
361
void ProcessPendingBatches ( )
Original file line number Diff line number Diff line change @@ -15,13 +15,25 @@ public static partial class data
15
15
/// </summary>
16
16
public abstract class Dataset : Dataset < Dictionary < string , torch . Tensor > >
17
17
{
18
+ public override void DisposeTensor ( Dictionary < string , Tensor > tensor )
19
+ {
20
+ foreach ( var t in tensor . Values ) {
21
+ t . Dispose ( ) ;
22
+ }
23
+ }
18
24
}
19
25
20
26
/// <summary>
21
27
/// Iterable-style data sets
22
28
/// </summary>
23
29
public abstract class IterableDataset : Dataset < IList < Tensor > >
24
30
{
31
+ public override void DisposeTensor ( IList < Tensor > tensor )
32
+ {
33
+ foreach ( var t in tensor ) {
34
+ t . Dispose ( ) ;
35
+ }
36
+ }
25
37
}
26
38
27
39
/// <summary>
@@ -47,6 +59,8 @@ public void Dispose()
47
59
/// <returns>Tensors of index. DataLoader will catenate these tensors into batches.</returns>
48
60
public abstract T GetTensor ( long index ) ;
49
61
62
+ public abstract void DisposeTensor ( T tensor ) ;
63
+
50
64
protected virtual void Dispose ( bool disposing )
51
65
{
52
66
}
Original file line number Diff line number Diff line change @@ -37,7 +37,7 @@ internal TensorDataset(torch.Tensor[] tensors)
37
37
long size1 = tensors [ 0 ] . shape [ 0 ] ;
38
38
if ( ! tensors . All ( t => t . shape [ 0 ] == size1 ) ) throw new ArgumentException ( "All tensors must have the same first dimension size." , nameof ( tensors ) ) ;
39
39
40
- _tensors . AddRange ( tensors ) ;
40
+ _tensors = tensors . Select ( x => x . alias ( ) ) . ToArray ( ) ;
41
41
}
42
42
43
43
/// <summary>
@@ -62,8 +62,16 @@ public override long Count {
62
62
return this [ index ] ;
63
63
}
64
64
65
- private List < torch . Tensor > _tensors = new List < torch . Tensor > ( ) ;
66
- }
65
+ private torch . Tensor [ ] _tensors ;
67
66
67
+ protected override void Dispose ( bool disposing )
68
+ {
69
+ if ( disposing ) {
70
+ foreach ( var tensor in _tensors ) {
71
+ tensor . Dispose ( ) ;
72
+ }
73
+ }
74
+ }
75
+ }
68
76
}
69
77
}
You can’t perform that action at this time.
0 commit comments