Commit 3d01e74

Merge pull request #1309 from yueyinqiu/fix
more fixes
2 parents d333a18 + da689bf commit 3d01e74

File tree

5 files changed: +293 -63 lines changed

RELEASENOTES.md

Lines changed: 3 additions & 1 deletion

@@ -16,6 +16,8 @@ __API Changes__:
 - #1291 `Tensor.grad()` and `Tensor.set_grad()` have been replaced by a new property `Tensor.grad`.
 - A potential memory leak caused by `set_grad` has been resolved.
 - `Include` method of dispose scopes has been removed. Use `Attach` instead.
+- Two more `Attach` methods that accept `IEnumerable<IDisposable>`s and arrays as parameters have been added to dispose scopes.
+- A new property `torch.CurrentDisposeScope` has been added to provide the ability to get the current dispose scope.

 __Bug Fixes__:

@@ -27,7 +29,7 @@ __Bug Fixes__:

 - `TensorDataset` will now keep the aliases detached from dispose scopes, to avoid the unexpected disposal.
-- `DataLoaderEnumerator` has been completely rewritten to resolve the unexpected shuffler disposal, the ignorance of drop last and the incorrect count of worker.
+- `DataLoaderEnumerator` has been completely rewritten to resolve the unexpected shuffler disposal, the ignoring of `drop_last`, the incorrect worker count, and the potential leak caused by multithreading.
 - #1303 Allow dispose scopes to be disposed out of LIFO order.

# NuGet Version 0.102.4

docfx/articles/memory.md

Lines changed: 101 additions & 0 deletions

@@ -295,6 +295,107 @@ public override Tensor entropy() => torch.WrappedTensorDisposeScope(() => ((scal
```

## Data loaders and datasets

Sometimes we'd like to train our model in the following pattern:

```csharp
using var dataLoader = torch.utils.data.DataLoader(...);
for (int epoch = 0; epoch < 100; epoch++)
{
    foreach (var batch in dataLoader)
    {
        using (torch.NewDisposeScope())
        {
            ...
        }
    }
}
```

In this pattern, you may notice that `batch` (at least the first batch) is created outside the dispose scope, which would cause a potential memory leak.

Of course we could dispose the batches manually, but we don't actually have to: the data loader automatically disposes each batch before the next iteration.

However, this can cause a different problem: for example, LINQ queries over the loader will yield tensors that have already been disposed. The behavior can be changed by setting `disposeBatch` to `false`:
```csharp
using TorchSharp;

using var dataset = torch.utils.data.TensorDataset(torch.zeros([3]));

using var dataLoader1 = torch.utils.data.DataLoader(dataset, batchSize: 1);
using var dataLoader2 = torch.utils.data.DataLoader(dataset, batchSize: 1, disposeBatch: false);

Console.WriteLine(dataLoader1.First()[0].IsInvalid); // True
Console.WriteLine(dataLoader2.First()[0].IsInvalid); // False
```

However, those tensors are detached from all dispose scopes, even if the whole process is wrapped in a scope. (Otherwise it could lead to confusion, since the iterations may not all happen in the same dispose scope.) So don't forget to dispose them later, or to attach them to a scope manually. Also, be aware that enumerating the same `IEnumerable` twice can produce different instances:
```csharp
// DON'T DO THIS:
using TorchSharp;

using var dataset = torch.utils.data.TensorDataset(torch.zeros([3]));
using var dataLoader = torch.utils.data.DataLoader(dataset, batchSize: 1, disposeBatch: false);

var tensors = dataLoader.Select(x => x[0]);
DoSomeThing(tensors.ToArray());

foreach (var tensor in tensors)
{
    tensor.Dispose();
    // DON'T DO THIS.
    // This tensor is not the one you passed into `DoSomeThing`;
    // re-enumerating `tensors` has produced a fresh instance.
}
```
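The safe pattern is to materialize the query exactly once and dispose those same instances afterwards. A minimal sketch of the fix, keeping the hypothetical `DoSomeThing` helper from the snippet above:

```csharp
using TorchSharp;

using var dataset = torch.utils.data.TensorDataset(torch.zeros([3]));
using var dataLoader = torch.utils.data.DataLoader(dataset, batchSize: 1, disposeBatch: false);

// Enumerate the loader exactly once and keep the resulting instances.
var tensors = dataLoader.Select(x => x[0]).ToArray();
DoSomeThing(tensors);

// These are the very tensors passed to DoSomeThing, so disposing them is safe.
foreach (var tensor in tensors)
    tensor.Dispose();
```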
Meanwhile, when writing a dataset of your own, note that the data loaders dispose the tensors created in `GetTensor` after collation. So a dataset like the following will not work, because the saved tensor will be disposed:

```csharp
using TorchSharp;

using var dataLoader = torch.utils.data.DataLoader(new MyDataset(), batchSize: 1);
foreach (var _ in dataLoader) ;
// System.InvalidOperationException:
// Tensor invalid -- empty handle.

class MyDataset : torch.utils.data.Dataset
{
    private torch.Tensor tensor = torch.zeros([]);
    public override Dictionary<string, torch.Tensor> GetTensor(long index)
    {
        tensor = tensor + 1;
        // The new tensor is attached to the dispose scope in the data loader,
        // and it will be disposed after collation,
        // so in the next iteration it becomes invalid.
        return new() { ["tensor"] = tensor };
    }

    public override long Count => 3;
}
```
Since the actual mechanism used to "catch" the tensors is just an ordinary dispose scope, we can detach the tensor from that scope to avoid the disposal:

```csharp
class MyDataset : torch.utils.data.Dataset
{
    private torch.Tensor tensor = torch.zeros([]);
    public override Dictionary<string, torch.Tensor> GetTensor(long index)
    {
        var previous = tensor;
        tensor = (previous + 1).DetachFromDisposeScope();
        previous.Dispose(); // Don't forget to dispose the previous one.
        return new() { ["tensor"] = tensor };
    }

    public override long Count => 3;
}
```
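With the dataset rewritten this way, the failing loop from the earlier snippet should run to completion. A quick check, under the same assumptions as the snippets above (not verified output):

```csharp
using TorchSharp;

// MyDataset here is the corrected version above: each tensor returned by
// GetTensor is detached from the loader's scope, so collation no longer
// invalidates it.
using var dataLoader = torch.utils.data.DataLoader(new MyDataset(), batchSize: 1);
foreach (var batch in dataLoader)
    Console.WriteLine(batch["tensor"].item<float>());
// No "Tensor invalid -- empty handle" exception this time.
```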
Also, if you want a "lazy" collate function, do not directly save the tensors that are passed in; `DetachFromDisposeScope` does not help here, because those tensors are tracked in a separate list rather than in dispose scopes, due to multithreading concerns. Instead, create aliases for them.
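Concretely, inside a custom collate function, store `tensor.alias()` (a fresh handle over the same storage) rather than the tensor that was passed in. The class below is an illustrative sketch, not an API from this commit; `alias` and `DetachFromDisposeScope` are the only TorchSharp calls it relies on:

```csharp
using System.Collections.Generic;
using System.Linq;
using TorchSharp;

class CachingCollate
{
    // Illustrative cache of tensors seen by the collate function.
    private readonly List<torch.Tensor> cache = new();

    public IList<torch.Tensor> Collate(IEnumerable<IList<torch.Tensor>> items, torch.Device device)
    {
        var batch = items.First();
        foreach (var tensor in batch) {
            // Don't store `tensor` itself: the loader disposes it after collation,
            // and DetachFromDisposeScope cannot save it here. An alias shares the
            // storage but has an independent lifetime.
            cache.Add(tensor.alias().DetachFromDisposeScope());
        }
        return batch.ToList();
    }
}
```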

## Links and resources

src/TorchSharp/DataLoader.cs

Lines changed: 155 additions & 50 deletions

@@ -18,24 +18,64 @@ public static partial class utils
 public static partial class data
 {
-    public static Modules.DataLoader DataLoader(Dataset dataset, int batchSize, IEnumerable<long> shuffler, Device device = null, int num_worker = 1, bool drop_last = false)
+    public static Modules.DataLoader DataLoader(
+        Dataset dataset,
+        int batchSize, IEnumerable<long> shuffler,
+        Device device = null,
+        int num_worker = 1, bool drop_last = false,
+        bool disposeBatch = true, bool disposeDataset = true)
     {
-        return new Modules.DataLoader(dataset, batchSize,shuffler, device, num_worker, drop_last);
+        return new Modules.DataLoader(
+            dataset,
+            batchSize, shuffler,
+            device,
+            num_worker, drop_last,
+            disposeBatch, disposeDataset);
     }

-    public static Modules.DataLoader DataLoader(Dataset dataset, int batchSize, bool shuffle = false, Device device = null, int? seed = null, int num_worker = 1, bool drop_last = false)
+    public static Modules.DataLoader DataLoader(
+        Dataset dataset,
+        int batchSize, bool shuffle = false,
+        Device device = null, int? seed = null,
+        int num_worker = 1, bool drop_last = false,
+        bool disposeBatch = true, bool disposeDataset = true)
     {
-        return new Modules.DataLoader(dataset,batchSize,shuffle, device, seed, num_worker,drop_last);
+        return new Modules.DataLoader(
+            dataset,
+            batchSize, shuffle,
+            device, seed,
+            num_worker, drop_last,
+            disposeBatch, disposeDataset);
     }

-    public static Modules.IterableDataLoader DataLoader(IterableDataset dataset, int batchSize, IEnumerable<long> shuffler, Device device = null, int num_worker = 1, bool drop_last = false)
+    public static Modules.IterableDataLoader DataLoader(
+        IterableDataset dataset,
+        int batchSize, IEnumerable<long> shuffler,
+        Device device = null,
+        int num_worker = 1, bool drop_last = false,
+        bool disposeBatch = true, bool disposeDataset = true)
     {
-        return new Modules.IterableDataLoader(dataset, batchSize, shuffler, device, num_worker, drop_last);
+        return new Modules.IterableDataLoader(
+            dataset,
+            batchSize, shuffler,
+            device,
+            num_worker, drop_last,
+            disposeBatch, disposeDataset);
     }

-    public static Modules.IterableDataLoader DataLoader(IterableDataset dataset, int batchSize, bool shuffle = false, Device device = null, int? seed = null, int num_worker = 1, bool drop_last = false)
+    public static Modules.IterableDataLoader DataLoader(
+        IterableDataset dataset,
+        int batchSize, bool shuffle = false,
+        Device device = null, int? seed = null,
+        int num_worker = 1, bool drop_last = false,
+        bool disposeBatch = true, bool disposeDataset = true)
     {
-        return new Modules.IterableDataLoader(dataset, batchSize, shuffle, device, seed, num_worker, drop_last);
+        return new Modules.IterableDataLoader(
+            dataset,
+            batchSize, shuffle,
+            device, seed,
+            num_worker, drop_last,
+            disposeBatch, disposeDataset);
     }
 }

@@ -64,8 +104,23 @@ public class DataLoader : DataLoader<Dictionary<string, torch.Tensor>, Dictionar
 /// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
 /// If false and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
 /// </param>
-public DataLoader(Dataset dataset, int batchSize, IEnumerable<long> shuffler, Device device = null, int num_worker = 1, bool drop_last = false)
-    : base(dataset, batchSize, Collate, shuffler, device, num_worker, drop_last)
+/// <param name="disposeBatch">
+/// Indicates whether to automatically dispose the collated tensors after an iteration.
+/// </param>
+/// <param name="disposeDataset">
+/// Indicates whether to dispose the dataset when being disposed.
+/// </param>
+public DataLoader(
+    Dataset dataset,
+    int batchSize, IEnumerable<long> shuffler,
+    Device device = null,
+    int num_worker = 1, bool drop_last = false,
+    bool disposeBatch = true, bool disposeDataset = true)
+    : base(dataset,
+        batchSize, Collate, shuffler,
+        device,
+        num_worker, drop_last,
+        disposeBatch, disposeDataset)
 {
 }

@@ -82,8 +137,23 @@ public DataLoader(Dataset dataset, int batchSize, IEnumerable<long> shuffler, De
 /// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
 /// If false and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
 /// </param>
-public DataLoader(Dataset dataset, int batchSize, bool shuffle = false, Device device = null, int? seed = null, int num_worker = 1, bool drop_last = false)
-    : base(dataset, batchSize, Collate, shuffle, device, seed, num_worker, drop_last)
+/// <param name="disposeBatch">
+/// Indicates whether to automatically dispose the collated tensors after an iteration.
+/// </param>
+/// <param name="disposeDataset">
+/// Indicates whether to dispose the dataset when being disposed.
+/// </param>
+public DataLoader(
+    Dataset dataset,
+    int batchSize, bool shuffle = false,
+    Device device = null, int? seed = null,
+    int num_worker = 1, bool drop_last = false,
+    bool disposeBatch = true, bool disposeDataset = true)
+    : base(dataset,
+        batchSize, Collate, shuffle,
+        device, seed,
+        num_worker, drop_last,
+        disposeBatch, disposeDataset)
 {
 }

@@ -120,8 +190,23 @@ public class IterableDataLoader : DataLoader<IList<torch.Tensor>, IList<torch.Te
 /// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
 /// If false and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
 /// </param>
-public IterableDataLoader(IterableDataset dataset, int batchSize, IEnumerable<long> shuffler, Device device = null, int num_worker = 1, bool drop_last = false)
-    : base(dataset, batchSize, Collate, shuffler, device, num_worker, drop_last)
+/// <param name="disposeBatch">
+/// Indicates whether to automatically dispose the collated tensors after an iteration.
+/// </param>
+/// <param name="disposeDataset">
+/// Indicates whether to dispose the dataset when being disposed.
+/// </param>
+public IterableDataLoader(
+    IterableDataset dataset,
+    int batchSize, IEnumerable<long> shuffler,
+    Device device = null,
+    int num_worker = 1, bool drop_last = false,
+    bool disposeBatch = true, bool disposeDataset = true)
+    : base(dataset,
+        batchSize, Collate, shuffler,
+        device,
+        num_worker, drop_last,
+        disposeBatch, disposeDataset)
 {
 }

@@ -138,8 +223,23 @@ public IterableDataLoader(IterableDataset dataset, int batchSize, IEnumerable<lo
 /// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
 /// If false and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
 /// </param>
-public IterableDataLoader(IterableDataset dataset, int batchSize, bool shuffle = false, Device device = null, int? seed = null, int num_worker = 1, bool drop_last = false)
-    : base(dataset, batchSize, Collate, shuffle, device, seed, num_worker, drop_last)
+/// <param name="disposeBatch">
+/// Indicates whether to automatically dispose the collated tensors after an iteration.
+/// </param>
+/// <param name="disposeDataset">
+/// Indicates whether to dispose the dataset when being disposed.
+/// </param>
+public IterableDataLoader(
+    IterableDataset dataset,
+    int batchSize, bool shuffle = false,
+    Device device = null, int? seed = null,
+    int num_worker = 1, bool drop_last = false,
+    bool disposeBatch = true, bool disposeDataset = true)
+    : base(dataset,
+        batchSize, Collate, shuffle,
+        device, seed,
+        num_worker, drop_last,
+        disposeBatch, disposeDataset)
 {
 }

@@ -296,12 +396,11 @@ public DataLoaderEnumerator(DataLoader<T, S> loader)
     Reset();
 }

-private long? MoveNextValue()
+private static void DisposeAll(HashSet<IDisposable> disposables)
 {
-    if (!shuffler.MoveNext()) {
-        return null;
+    foreach (var disposable in disposables) {
+        disposable.Dispose();
     }
-    return shuffler.Current;
 }

 /// <summary>
@@ -312,35 +411,42 @@ public bool MoveNext()
 {
     DisposeCurrent();

-    using (var scope = torch.NewDisposeScope()) {
-        var indices = Enumerable.Range(0, loader.batchSize)
-            .Select(_ => MoveNextValue())
-            .Where(x => x.HasValue)
-            .Cast<long>()
-            .ToArray();
-        if (indices.Length is 0)
-            return false;
-        if (loader.drop_last && indices.Length < loader.batchSize) {
-            return false;
-        }
-
-        var tensors = new T[indices.Length];
-        Enumerable.Range(0, indices.Length)
-            .AsParallel()
-            .WithDegreeOfParallelism(loader.num_worker)
-            .ForAll((i) => {
-                tensors[i] = loader.dataset.GetTensor(indices[i]);
-            });
-
-        using var collate_scope = torch.NewDisposeScope();
-        current = loader.collate_fn(tensors, loader.device);
-        var disposables = collate_scope.DetachAllAndDispose();
-        if (loader.disposeBatch) {
-            this.currentDisposables = disposables;
-        }
-
-        return true;
+    var indices = Enumerable.Range(0, loader.batchSize)
+        .Select(_ => shuffler.MoveNext() ? shuffler.Current : (long?)null)
+        .Where(x => x.HasValue)
+        .Cast<long>()
+        .ToArray();
+
+    if (indices.Length is 0)
+        return false;
+
+    if (loader.drop_last && indices.Length < loader.batchSize) {
+        return false;
+    }
+
+    var tensors = new T[indices.Length];
+    var getTensorDisposables = new HashSet<IDisposable>[indices.Length];
+    Enumerable.Range(0, indices.Length)
+        .AsParallel()
+        .WithDegreeOfParallelism(loader.num_worker)
+        .ForAll((i) => {
+            using var getTensorScope = torch.NewDisposeScope();
+            tensors[i] = loader.dataset.GetTensor(indices[i]);
+            getTensorDisposables[i] = getTensorScope.DetachAllAndDispose();
+        });
+
+    using var collateScope = torch.NewDisposeScope();
+    this.current = loader.collate_fn(tensors, loader.device);
+    var collateDisposables = collateScope.DetachAllAndDispose();
+    if (loader.disposeBatch) {
+        this.currentDisposables = collateDisposables;
+    }
+
+    foreach (var set in getTensorDisposables) {
+        DisposeAll(set);
     }
+
+    return true;
 }

 /// <summary>
@@ -370,8 +476,7 @@ public void Dispose()
 private void DisposeCurrent()
 {
     if (this.currentDisposables is not null) {
-        foreach (var x in this.currentDisposables)
-            x.Dispose();
+        DisposeAll(this.currentDisposables);
         this.currentDisposables = null;
     }
 }
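The reworked overloads above can be combined to share one dataset across several loaders. A hedged usage sketch (the training/evaluation split here is illustrative, not part of the commit):

```csharp
using TorchSharp;

var dataset = torch.utils.data.TensorDataset(torch.zeros([8]));

// disposeDataset: false keeps `dataset` alive after a loader is disposed,
// so the same dataset can back several loaders in turn.
using (var trainLoader = torch.utils.data.DataLoader(
           dataset, batchSize: 4, shuffle: true, disposeDataset: false))
{
    foreach (var batch in trainLoader) { /* train step */ }
}

using (var evalLoader = torch.utils.data.DataLoader(
           dataset, batchSize: 4, disposeDataset: false))
{
    foreach (var batch in evalLoader) { /* eval step */ }
}

dataset.Dispose(); // disposing the dataset is now our responsibility
```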
