
Commit 38e9667

doc
1 parent f3e7733 commit 38e9667

4 files changed (+130 −5 lines changed)


RELEASENOTES.md

Lines changed: 3 additions & 1 deletion
@@ -16,6 +16,8 @@ __API Changes__:
 - #1291 `Tensor.grad()` and `Tensor.set_grad()` have been replaced by a new property `Tensor.grad`.
 - A potential memory leak caused by `set_grad` has been resolved.
 - `Include` method of dispose scopes has been removed. Use `Attach` instead.
+- Two more `Attach` methods, accepting an `IEnumerable<IDisposable>` or an array as the parameter, have been added to dispose scopes.
+- A new property `torch.CurrentDisposeScope` has been added, providing the ability to get the current dispose scope.

 __Bug Fixes__:

@@ -27,7 +29,7 @@ __Bug Fixes__:
 __Bug Fixes__:

 - `TensorDataset` will now keep the aliases detached from dispose scopes, to avoid the unexpected disposal.
-- `DataLoaderEnumerator` has been completely rewritten to resolve the unexpected shuffler disposal, the ignorance of drop last and the incorrect count of worker.
+- `DataLoaderEnumerator` has been completely rewritten to resolve the unexpected shuffler disposal, the ignoring of `drop_last`, the incorrect worker count, and the potential leak caused by multithreading.
 - #1303 Allow dispose scopes to be disposed out of LIFO order.

 # NuGet Version 0.102.4

docfx/articles/memory.md

Lines changed: 101 additions & 0 deletions
@@ -295,6 +295,107 @@ public override Tensor entropy() => torch.WrappedTensorDisposeScope(() => ((scal

## Data loaders and datasets

Sometimes we'd like to train our model in the following pattern:

```csharp
using var dataLoader = torch.utils.data.DataLoader(...);
for (int epoch = 0; epoch < 100; epoch++)
{
    foreach (var batch in dataLoader)
    {
        using (torch.NewDisposeScope())
        {
            ...
        }
    }
}
```

In this case, you may notice that `batch` (at least the first batch) is created outside the dispose scope, which could cause a memory leak.

We could dispose of those tensors manually, but we don't actually have to, because the data loader automatically disposes each batch before the next iteration begins.

However, this behavior can cause another problem: for example, we will get disposed tensors when using LINQ. The behavior can be changed by setting `disposeBatch` to `false`:

```csharp
using TorchSharp;

using var dataset = torch.utils.data.TensorDataset(torch.zeros([3]));

using var dataLoader1 = torch.utils.data.DataLoader(dataset, batchSize: 1);
using var dataLoader2 = torch.utils.data.DataLoader(dataset, batchSize: 1, disposeBatch: false);

Console.WriteLine(dataLoader1.First()[0].IsInvalid); // True
Console.WriteLine(dataLoader2.First()[0].IsInvalid); // False
```

Those tensors are then detached from all dispose scopes, even if the whole process is wrapped in a scope. (Otherwise it could lead to confusion, since the iterations may not all happen in the same dispose scope.) So don't forget to dispose them later, or to manually attach them to a scope. Also, be aware that enumerating the same `IEnumerable` twice can produce different instances:

```csharp
// DON'T DO THIS:
using TorchSharp;

using var dataset = torch.utils.data.TensorDataset(torch.zeros([3]));
using var dataLoader = torch.utils.data.DataLoader(dataset, batchSize: 1, disposeBatch: false);

var tensors = dataLoader.Select(x => x[0]);
DoSomeThing(tensors.ToArray());

foreach (var tensor in tensors)
{
    tensor.Dispose();
    // DON'T DO THIS.
    // The tensor is not the one you have passed into `DoSomeThing`.
}
```
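A safer version of the pattern above is to materialize the sequence once, so the instances you dispose are the same ones you used. This is a sketch under the same setup as above; `DoSomeThing` stands in for your own code:

```csharp
using TorchSharp;

using var dataset = torch.utils.data.TensorDataset(torch.zeros([3]));
using var dataLoader = torch.utils.data.DataLoader(dataset, batchSize: 1, disposeBatch: false);

// Materialize the query exactly once, so re-enumeration
// cannot produce fresh batch instances.
var tensors = dataLoader.Select(x => x[0]).ToArray();
DoSomeThing(tensors);

foreach (var tensor in tensors)
{
    // These are the same instances that were passed to DoSomeThing.
    tensor.Dispose();
}
```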
Meanwhile, if you want to write a dataset of your own, note that data loaders dispose the tensors obtained from `GetTensor` after collation. So a dataset like this will not work, because the saved tensor is disposed:

```csharp
using TorchSharp;

using var dataLoader = torch.utils.data.DataLoader(new MyDataset(), batchSize: 1);
foreach (var _ in dataLoader) ;
// System.InvalidOperationException:
// Tensor invalid -- empty handle.

class MyDataset : torch.utils.data.Dataset
{
    private torch.Tensor tensor = torch.zeros([]);
    public override Dictionary<string, torch.Tensor> GetTensor(long index)
    {
        tensor = tensor + 1;
        // The new tensor is attached to the dispose scope in the data loader,
        // and it will be disposed after collation,
        // so in the next iteration it becomes invalid.
        return new() { ["tensor"] = tensor };
    }

    public override long Count => 3;
}
```

The actual technique used to "catch" the tensors is just a simple dispose scope, so we can write the class like this to avoid the disposal:

```csharp
class MyDataset : torch.utils.data.Dataset
{
    private torch.Tensor tensor = torch.zeros([]);
    public override Dictionary<string, torch.Tensor> GetTensor(long index)
    {
        var previous = tensor;
        tensor = (previous + 1).DetachFromDisposeScope();
        previous.Dispose(); // Don't forget to dispose the previous one.
        return new() { ["tensor"] = tensor };
    }

    public override long Count => 3;
}
```
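Since detached tensors remain the caller's responsibility, the `torch.CurrentDisposeScope` property and the new `Attach` overloads introduced in this commit can hand them back to whatever scope is active. A sketch (note that `CurrentDisposeScope` is `null` when no scope is open, hence the null-conditional call):

```csharp
using TorchSharp;

using var dataset = torch.utils.data.TensorDataset(torch.zeros([3]));
using var dataLoader = torch.utils.data.DataLoader(dataset, batchSize: 1, disposeBatch: false);

using (torch.NewDisposeScope())
{
    var tensors = dataLoader.Select(x => x[0]).ToArray();
    // Attach the detached tensors to the current scope,
    // so they are disposed when the scope ends.
    torch.CurrentDisposeScope?.Attach(tensors);
    // ... use the tensors ...
}
```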

## Links and resources

src/TorchSharp/DisposeScope.cs

Lines changed: 20 additions & 4 deletions
@@ -218,15 +218,31 @@ public void Detach(IEnumerable<IDisposable> disposables)
 }

 public void Attach(IDisposable disposable)
+{
+    _ = Attach((IEnumerable<IDisposable>)new[] { disposable });
+}
+
+public void Attach(params IDisposable[] disposables)
+{
+    _ = Attach((IEnumerable<IDisposable>)disposables);
+}
+
+public IReadOnlyList<IDisposable> Attach(IEnumerable<IDisposable> disposables)
 {
     if (this._disposeScopeManager is null)
         throw new ObjectDisposedException(this.GetType().FullName);
-    if (disposable is torch.Tensor tensor) {
-        if (tensor.OwningDisposeScope == null && !tensor.IsInvalid) {
-            _disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--;
+
+    var result = new List<IDisposable>();
+    foreach (var disposable in disposables) {
+        if (disposable is torch.Tensor tensor) {
+            if (tensor.OwningDisposeScope == null && !tensor.IsInvalid) {
+                _disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--;
+            }
         }
+        AddToOther(this, disposable);
+        result.Add(disposable);
     }
-    AddToOther(this, disposable);
+    return result;
 }

 /// <summary>

src/TorchSharp/Tensor/Tensor.cs

Lines changed: 6 additions & 0 deletions
@@ -7377,6 +7377,12 @@ public static long max_int_value(ScalarType type)
 public static DisposeScope NewDisposeScope() =>
     DisposeScopeManager.ThreadSingleton.NewDisposeScope();

+/// <summary>
+/// Get the current dispose scope for the current thread.
+/// </summary>
+public static DisposeScope? CurrentDisposeScope =>
+    DisposeScopeManager.ThreadSingleton.CurrentDisposeScope;
+
 /// <summary>
 /// Creates a new dispose scope for the current thread, wrapping an expression.
 /// </summary>
