Skip to content

Commit 5be4440

Browse files
committed
add disposetensor in dataset
1 parent 6b7c873 commit 5be4440

File tree

3 files changed

+29
-3
lines changed

3 files changed

+29
-3
lines changed

src/TorchSharp/DataLoader.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,10 @@ public bool MoveNext()
352352
Current = collate_fn(items, device);
353353
}
354354

355+
foreach (var item in items) {
356+
dataset.DisposeTensor(item);
357+
}
358+
355359
return true;
356360

357361
void ProcessPendingBatches()

src/TorchSharp/Dataset.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,25 @@ public static partial class data
1515
/// </summary>
1616
public abstract class Dataset : Dataset<Dictionary<string, torch.Tensor>>
1717
{
18+
public override void DisposeTensor(Dictionary<string, Tensor> tensor)
19+
{
20+
foreach (var t in tensor.Values) {
21+
t.Dispose();
22+
}
23+
}
1824
}
1925

2026
/// <summary>
2127
/// Iterable-style data sets
2228
/// </summary>
2329
public abstract class IterableDataset : Dataset<IList<Tensor>>
2430
{
31+
public override void DisposeTensor(IList<Tensor> tensor)
32+
{
33+
foreach (var t in tensor) {
34+
t.Dispose();
35+
}
36+
}
2537
}
2638

2739
/// <summary>
@@ -47,6 +59,8 @@ public void Dispose()
4759
/// <returns>Tensors of index. DataLoader will catenate these tensors into batches.</returns>
4860
public abstract T GetTensor(long index);
4961

62+
public abstract void DisposeTensor(T tensor);
63+
5064
protected virtual void Dispose(bool disposing)
5165
{
5266
}

src/TorchSharp/Utils/TensorDataset.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ internal TensorDataset(torch.Tensor[] tensors)
3737
long size1 = tensors[0].shape[0];
3838
if (!tensors.All(t => t.shape[0] == size1)) throw new ArgumentException("All tensors must have the same first dimension size.", nameof(tensors));
3939

40-
_tensors.AddRange(tensors);
40+
_tensors = tensors.Select(x => x.alias()).ToArray();
4141
}
4242

4343
/// <summary>
@@ -62,8 +62,16 @@ public override long Count {
6262
return this[index];
6363
}
6464

65-
private List<torch.Tensor> _tensors = new List<torch.Tensor>();
66-
}
65+
private torch.Tensor[] _tensors;
6766

67+
protected override void Dispose(bool disposing)
68+
{
69+
if (disposing) {
70+
foreach (var tensor in _tensors) {
71+
tensor.Dispose();
72+
}
73+
}
74+
}
75+
}
6876
}
6977
}

0 commit comments

Comments
 (0)