Skip to content

Commit b6f55da

Browse files
Merge pull request #1318 from NiklasGustafsson/missing
Add a input-free version of pre- and post-hooks that can manipulate t…
2 parents 06a83e3 + b5be13e commit b6f55da

File tree

3 files changed

+134
-26
lines changed

3 files changed

+134
-26
lines changed

RELEASENOTES.md

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,33 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the
66

77
__Breaking Changes__:
88

9-
- `torchvision.dataset.MNIST` will try more mirrors.
10-
- The thrown exception might be changed when it fails to download `MNIST`, `FashionMNIST` or `KMNIST`.
11-
- `ObjectDisposedException` will now be thrown when trying to use the disposed dispose scopes.
12-
- The constructor of dispose scopes is no longer `public`. Use `torch.NewDisposeScope` instead.
9+
`torchvision.dataset.MNIST` will try more mirrors. The thrown exception might be changed when it fails to download `MNIST`, `FashionMNIST` or `KMNIST`.<br/>
10+
`ObjectDisposedException` will now be thrown when trying to use a disposed dispose scopes.<br/>
11+
The constructor of dispose scopes is no longer `public`. Use `torch.NewDisposeScope` instead.<br/>
1312

1413
__API Changes__:
1514

16-
- #1291 `Tensor.grad()` and `Tensor.set_grad()` have been replaced by a new property `Tensor.grad`.
17-
- A potential memory leak caused by `set_grad` has been resolved.
18-
- `Include` method of dispose scopes has been removed. Use `Attach` instead.
19-
- Two more `Attach` methods that accepts `IEnumerable<IDisposable>`s and arrays as the parameter have been added into dispose scopes.
20-
- A new property `torch.CurrentDisposeScope` has been added to provide the ability to get the current dispose scope.
15+
#1314 Grant read-only access to DataLoader attributes<br/>
16+
#1313 Add 'non_blocking' argument to tensor and module 'to()' signatures.<br/>
17+
#1291 `Tensor.grad()` and `Tensor.set_grad()` have been replaced by a new property `Tensor.grad`.<br/>
18+
A potential memory leak caused by `set_grad` has been resolved.<br/>
19+
`Include` method of dispose scopes has been removed. Use `Attach` instead.<br/>
20+
Two more `Attach` methods that accepts `IEnumerable<IDisposable>`s and arrays as the parameter have been added into dispose scopes.<br/>
21+
A new property `torch.CurrentDisposeScope` has been added to provide the ability to get the current dispose scope.<br/>
22+
Add module hooks that take no input/output arguments, just the module itself.<br/>
2123

2224
__Bug Fixes__:
2325

24-
- #1300 `Adadelta`, `Adam` and `AdamW` will no longer throw `NullReferenceException` when `maximize` is `true` and `grad` is `null`.
25-
- `torch.normal` will now correctly return a leaf tensor.
26-
- New options `disposeBatch` and `disposeDataset` have been added into `DataLoader`.
27-
- The default collate functions will now always dispose the intermediate tensors, rather than wait for the next iteration.
28-
- The fields are now exposed as readonly properties.
26+
#1300 `Adadelta`, `Adam` and `AdamW` will no longer throw `NullReferenceException` when `maximize` is `true` and `grad` is `null`.<br/>
27+
torch.normal` will now correctly return a leaf tensor.<br/>
28+
New options `disposeBatch` and `disposeDataset` have been added into `DataLoader`.<br/>
29+
The default collate functions will now always dispose the intermediate tensors, rather than wait for the next iteration.<br/>
2930

3031
__Bug Fixes__:
3132

32-
- `TensorDataset` will now keep the aliases detached from dispose scopes, to avoid the unexpected disposal.
33-
- `DataLoaderEnumerator` has been completely rewritten to resolve the unexpected shuffler disposal, the ignorance of `drop_last`, the incorrect count of worker, and the potential leak cause by multithreading.
34-
- #1303 Allow dispose scopes to be disposed out of LIFO order.
33+
`TensorDataset` will now keep the aliases detached from dispose scopes, to avoid the unexpected disposal.<br/>
34+
`DataLoaderEnumerator` has been completely rewritten to resolve the unexpected shuffler disposal, the ignorance of `drop_last`, the incorrect count of worker, and the potential leak cause by multithreading.<br/>
35+
#1303 Allow dispose scopes to be disposed out of LIFO order.<br/>
3536

3637
# NuGet Version 0.102.4
3738

src/TorchSharp/NN/Module.cs

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,7 @@ public interface IModule<T1, T2, T3, T4, T5, T6, TResult>
12281228
/// <summary>
12291229
/// Represents a module that accepts 'hook' to the module logic.
12301230
/// </summary>
1231-
public class HookableModule<TPreHook,TPostHook> : Module
1231+
public class HookableModule<TModule,TPreHook,TPostHook> : Module
12321232
{
12331233
protected HookableModule(string name) : base(name) { }
12341234

@@ -1249,22 +1249,41 @@ public HookRemover register_forward_pre_hook(TPreHook hook)
12491249
return new HookRemover(this, key);
12501250
}
12511251

1252+
public HookRemover register_forward_hook(Action<TModule> hook)
1253+
{
1254+
var key = Guid.NewGuid().ToString();
1255+
module_post_hooks.Add(key, hook);
1256+
return new HookRemover(this, key);
1257+
}
1258+
1259+
public HookRemover register_forward_pre_hook(Action<TModule> hook)
1260+
{
1261+
var key = Guid.NewGuid().ToString();
1262+
module_pre_hooks.Add(key, hook);
1263+
return new HookRemover(this, key);
1264+
}
1265+
12521266
private void remove(string key)
12531267
{
12541268
if (pre_hooks.ContainsKey(key)) pre_hooks.Remove(key);
12551269
if (post_hooks.ContainsKey(key)) post_hooks.Remove(key);
1270+
if (module_pre_hooks.ContainsKey(key)) module_pre_hooks.Remove(key);
1271+
if (module_post_hooks.ContainsKey(key)) module_post_hooks.Remove(key);
12561272
}
12571273

12581274
protected Dictionary<string, TPreHook> pre_hooks = new Dictionary<string, TPreHook>();
12591275
protected Dictionary<string, TPostHook> post_hooks = new Dictionary<string, TPostHook>();
12601276

1277+
protected Dictionary<string, Action<TModule>> module_pre_hooks = new Dictionary<string, Action<TModule>>();
1278+
protected Dictionary<string, Action<TModule>> module_post_hooks = new Dictionary<string, Action<TModule>>();
1279+
12611280
/// <summary>
12621281
/// Used to remove a specific hook, following the PyTorch API design.
12631282
/// </summary>
12641283
/// <remarks>The name and namespace of this class is not the same as in PyTorch, but serves the same purpose.</remarks>
12651284
public class HookRemover
12661285
{
1267-
public HookRemover(HookableModule<TPreHook, TPostHook> module, string key)
1286+
public HookRemover(HookableModule<TModule,TPreHook, TPostHook> module, string key)
12681287
{
12691288
this.module = module;
12701289
this.key = key;
@@ -1275,7 +1294,7 @@ public void remove()
12751294
module.remove(key);
12761295
}
12771296

1278-
private HookableModule<TPreHook, TPostHook> module;
1297+
private HookableModule<TModule,TPreHook, TPostHook> module;
12791298
private string key;
12801299
}
12811300
}
@@ -1285,7 +1304,7 @@ public void remove()
12851304
/// </summary>
12861305
/// <typeparam name="T">The argument type of the module's forward() function.</typeparam>
12871306
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1288-
public abstract class Module<T, TResult> : HookableModule<Func<Module<T,TResult>, T, T>, Func<Module<T, TResult>, T, TResult, TResult>>, IModule<T, TResult>
1307+
public abstract class Module<T, TResult> : HookableModule<Module<T, TResult>, Func<Module<T,TResult>, T, T>, Func<Module<T, TResult>, T, TResult, TResult>>, IModule<T, TResult>
12891308
{
12901309
protected Module(string name) : base(name) { }
12911310
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1305,6 +1324,10 @@ public TResult call(T input)
13051324
{
13061325
// Call pre-hooks, if available.
13071326

1327+
foreach (var hook in module_pre_hooks.Values) {
1328+
hook(this);
1329+
}
1330+
13081331
foreach (var hook in pre_hooks.Values) {
13091332
var modified = hook(this, input);
13101333
if (modified is not null)
@@ -1321,6 +1344,10 @@ public TResult call(T input)
13211344
result = modified;
13221345
}
13231346

1347+
foreach (var hook in module_post_hooks.Values) {
1348+
hook(this);
1349+
}
1350+
13241351
return result;
13251352
}
13261353
}
@@ -1331,7 +1358,7 @@ public TResult call(T input)
13311358
/// <typeparam name="T1">The first argument type of the module's forward() function.</typeparam>
13321359
/// <typeparam name="T2">The second argument type of the module's forward() function.</typeparam>
13331360
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1334-
public abstract class Module<T1, T2, TResult> : HookableModule<Func<Module<T1, T2, TResult>, T1, T2, (T1, T2)?>, Func<Module<T1, T2, TResult>, T1, T2, TResult, TResult>>, IModule<T1, T2, TResult>
1361+
public abstract class Module<T1, T2, TResult> : HookableModule<Module<T1, T2, TResult>, Func<Module<T1, T2, TResult>, T1, T2, (T1, T2)?>, Func<Module<T1, T2, TResult>, T1, T2, TResult, TResult>>, IModule<T1, T2, TResult>
13351362
{
13361363
protected Module(string name) : base(name) { }
13371364
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1351,6 +1378,10 @@ public TResult call(T1 input1, T2 input2)
13511378
{
13521379
// Call pre-hooks, if available.
13531380

1381+
foreach (var hook in module_pre_hooks.Values) {
1382+
hook(this);
1383+
}
1384+
13541385
foreach (var hook in pre_hooks.Values) {
13551386
var modified = hook(this, input1, input2);
13561387
if (modified.HasValue) {
@@ -1369,6 +1400,10 @@ public TResult call(T1 input1, T2 input2)
13691400
result = modified;
13701401
}
13711402

1403+
foreach (var hook in module_post_hooks.Values) {
1404+
hook(this);
1405+
}
1406+
13721407
return result;
13731408
}
13741409
}
@@ -1380,7 +1415,7 @@ public TResult call(T1 input1, T2 input2)
13801415
/// <typeparam name="T2">The second argument type of the module's forward() function.</typeparam>
13811416
/// <typeparam name="T3">The third argument type of the module's forward() function.</typeparam>
13821417
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1383-
public abstract class Module<T1, T2, T3, TResult> : HookableModule<Func<Module<T1, T2, T3, TResult>, T1, T2, T3, (T1, T2, T3)?>, Func<Module<T1, T2, T3, TResult>, T1, T2, T3, TResult, TResult>>, IModule<T1, T2, T3, TResult>
1418+
public abstract class Module<T1, T2, T3, TResult> : HookableModule<Module<T1, T2, T3, TResult>,Func<Module<T1, T2, T3, TResult>, T1, T2, T3, (T1, T2, T3)?>, Func<Module<T1, T2, T3, TResult>, T1, T2, T3, TResult, TResult>>, IModule<T1, T2, T3, TResult>
13841419
{
13851420
protected Module(string name) : base(name) { }
13861421
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1400,6 +1435,10 @@ public TResult call(T1 input1, T2 input2, T3 input3)
14001435
{
14011436
// Call pre-hooks, if available.
14021437

1438+
foreach (var hook in module_pre_hooks.Values) {
1439+
hook(this);
1440+
}
1441+
14031442
foreach (var hook in pre_hooks.Values) {
14041443
var modified = hook(this, input1, input2, input3);
14051444
if (modified.HasValue) {
@@ -1419,6 +1458,10 @@ public TResult call(T1 input1, T2 input2, T3 input3)
14191458
result = modified;
14201459
}
14211460

1461+
foreach (var hook in module_post_hooks.Values) {
1462+
hook(this);
1463+
}
1464+
14221465
return result;
14231466
}
14241467
}
@@ -1431,7 +1474,7 @@ public TResult call(T1 input1, T2 input2, T3 input3)
14311474
/// <typeparam name="T3">The third argument type of the module's forward() function.</typeparam>
14321475
/// <typeparam name="T4">The fourth argument type of the module's forward() function.</typeparam>
14331476
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1434-
public abstract class Module<T1, T2, T3, T4, TResult> : HookableModule<Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, (T1, T2, T3, T4)?>, Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, TResult, TResult>>, IModule<T1, T2, T3, T4, TResult>
1477+
public abstract class Module<T1, T2, T3, T4, TResult> : HookableModule<Module<T1, T2, T3, T4, TResult>, Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, (T1, T2, T3, T4)?>, Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, TResult, TResult>>, IModule<T1, T2, T3, T4, TResult>
14351478
{
14361479
protected Module(string name) : base(name) { }
14371480
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1451,6 +1494,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4)
14511494
{
14521495
// Call pre-hooks, if available.
14531496

1497+
foreach (var hook in module_pre_hooks.Values) {
1498+
hook(this);
1499+
}
1500+
14541501
foreach (var hook in pre_hooks.Values) {
14551502
var modified = hook(this, input1, input2, input3, input4);
14561503
if (modified.HasValue) {
@@ -1471,6 +1518,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4)
14711518
result = modified;
14721519
}
14731520

1521+
foreach (var hook in module_post_hooks.Values) {
1522+
hook(this);
1523+
}
1524+
14741525
return result;
14751526
}
14761527
}
@@ -1484,7 +1535,7 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4)
14841535
/// <typeparam name="T4">The fourth argument type of the module's forward() function.</typeparam>
14851536
/// <typeparam name="T5">The fifth argument type of the module's forward() function.</typeparam>
14861537
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1487-
public abstract class Module<T1, T2, T3, T4, T5,TResult> : HookableModule<Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, (T1, T2, T3, T4, T5)?>, Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, TResult>
1538+
public abstract class Module<T1, T2, T3, T4, T5,TResult> : HookableModule<Module<T1, T2, T3, T4, T5,TResult>, Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, (T1, T2, T3, T4, T5)?>, Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, TResult>
14881539
{
14891540
protected Module(string name) : base(name) { }
14901541
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1504,6 +1555,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5)
15041555
{
15051556
// Call pre-hooks, if available.
15061557

1558+
foreach (var hook in module_pre_hooks.Values) {
1559+
hook(this);
1560+
}
1561+
15071562
foreach (var hook in pre_hooks.Values) {
15081563
var modified = hook(this, input1, input2, input3, input4, input5);
15091564
if (modified.HasValue) {
@@ -1525,6 +1580,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5)
15251580
result = modified;
15261581
}
15271582

1583+
foreach (var hook in module_post_hooks.Values) {
1584+
hook(this);
1585+
}
1586+
15281587
return result;
15291588
}
15301589
}
@@ -1539,7 +1598,7 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5)
15391598
/// <typeparam name="T5">The fifth argument type of the module's forward() function.</typeparam>
15401599
/// <typeparam name="T6">The sixth argument type of the module's forward() function.</typeparam>
15411600
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1542-
public abstract class Module<T1, T2, T3, T4, T5, T6, TResult> : HookableModule<Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, (T1, T2, T3, T4, T5, T6)?>, Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, T6, TResult>
1601+
public abstract class Module<T1, T2, T3, T4, T5, T6, TResult> : HookableModule<Module<T1, T2, T3, T4, T5, T6, TResult>, Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, (T1, T2, T3, T4, T5, T6)?>, Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, T6, TResult>
15431602
{
15441603
protected Module(string name) : base(name) { }
15451604
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1559,6 +1618,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5, T6 in
15591618
{
15601619
// Call pre-hooks, if available.
15611620

1621+
foreach (var hook in module_pre_hooks.Values) {
1622+
hook(this);
1623+
}
1624+
15621625
foreach (var hook in pre_hooks.Values) {
15631626
var modified = hook(this, input1, input2, input3, input4, input5, input6);
15641627
if (modified.HasValue) {
@@ -1581,6 +1644,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5, T6 in
15811644
result = modified;
15821645
}
15831646

1647+
foreach (var hook in module_post_hooks.Values) {
1648+
hook(this);
1649+
}
1650+
15841651
return result;
15851652
}
15861653
}

test/TorchSharpTest/NN.cs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6583,5 +6583,45 @@ public void TestModulePostHooks()
65836583
lin1.call(input);
65846584
Assert.Equal(1, counter);
65856585
}
6586+
6587+
[Fact]
6588+
public void TestModulePreHooksGeneric()
6589+
{
6590+
var lin1 = torch.nn.Linear(100, 10);
6591+
var input = torch.randn(32, 100, 100);
6592+
var counter = 0;
6593+
6594+
var pre_hook = (Module<Tensor, Tensor> m) => { counter += 1;};
6595+
6596+
var handle = lin1.register_forward_pre_hook(pre_hook);
6597+
6598+
lin1.call(input);
6599+
Assert.Equal(1, counter);
6600+
6601+
handle.remove();
6602+
6603+
lin1.call(input);
6604+
Assert.Equal(1, counter);
6605+
}
6606+
6607+
[Fact]
6608+
public void TestModulePostHooksGeneric()
6609+
{
6610+
var lin1 = torch.nn.Linear(100, 10);
6611+
var input = torch.randn(32, 100, 100);
6612+
var counter = 0;
6613+
6614+
var hook = (Module<Tensor, Tensor> m) => { counter += 1;};
6615+
6616+
var handle = lin1.register_forward_hook(hook);
6617+
6618+
lin1.call(input);
6619+
Assert.Equal(1, counter);
6620+
6621+
handle.remove();
6622+
6623+
lin1.call(input);
6624+
Assert.Equal(1, counter);
6625+
}
65866626
}
65876627
}

0 commit comments

Comments
 (0)