Skip to content

Commit 50d6a34

Browse files
Add a input-free version of pre- and post-hooks that can manipulate the module itself.
1 parent e1dd136 commit 50d6a34

File tree

2 files changed

+116
-9
lines changed

2 files changed

+116
-9
lines changed

src/TorchSharp/NN/Module.cs

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,7 +1216,7 @@ public interface IModule<T1, T2, T3, T4, T5, T6, TResult>
12161216
/// <summary>
12171217
/// Represents a module that accepts 'hook' to the module logic.
12181218
/// </summary>
1219-
public class HookableModule<TPreHook,TPostHook> : Module
1219+
public class HookableModule<TModule,TPreHook,TPostHook> : Module
12201220
{
12211221
protected HookableModule(string name) : base(name) { }
12221222

@@ -1237,22 +1237,41 @@ public HookRemover register_forward_pre_hook(TPreHook hook)
12371237
return new HookRemover(this, key);
12381238
}
12391239

1240+
public HookRemover register_forward_hook(Action<TModule> hook)
1241+
{
1242+
var key = Guid.NewGuid().ToString();
1243+
module_post_hooks.Add(key, hook);
1244+
return new HookRemover(this, key);
1245+
}
1246+
1247+
public HookRemover register_forward_pre_hook(Action<TModule> hook)
1248+
{
1249+
var key = Guid.NewGuid().ToString();
1250+
module_pre_hooks.Add(key, hook);
1251+
return new HookRemover(this, key);
1252+
}
1253+
12401254
private void remove(string key)
12411255
{
12421256
if (pre_hooks.ContainsKey(key)) pre_hooks.Remove(key);
12431257
if (post_hooks.ContainsKey(key)) post_hooks.Remove(key);
1258+
if (module_pre_hooks.ContainsKey(key)) module_pre_hooks.Remove(key);
1259+
if (module_post_hooks.ContainsKey(key)) module_post_hooks.Remove(key);
12441260
}
12451261

12461262
protected Dictionary<string, TPreHook> pre_hooks = new Dictionary<string, TPreHook>();
12471263
protected Dictionary<string, TPostHook> post_hooks = new Dictionary<string, TPostHook>();
12481264

1265+
protected Dictionary<string, Action<TModule>> module_pre_hooks = new Dictionary<string, Action<TModule>>();
1266+
protected Dictionary<string, Action<TModule>> module_post_hooks = new Dictionary<string, Action<TModule>>();
1267+
12491268
/// <summary>
12501269
/// Used to remove a specific hook, following the PyTorch API design.
12511270
/// </summary>
12521271
/// <remarks>The name and namespace of this class is not the same as in PyTorch, but serves the same purpose.</remarks>
12531272
public class HookRemover
12541273
{
1255-
public HookRemover(HookableModule<TPreHook, TPostHook> module, string key)
1274+
public HookRemover(HookableModule<TModule,TPreHook, TPostHook> module, string key)
12561275
{
12571276
this.module = module;
12581277
this.key = key;
@@ -1263,7 +1282,7 @@ public void remove()
12631282
module.remove(key);
12641283
}
12651284

1266-
private HookableModule<TPreHook, TPostHook> module;
1285+
private HookableModule<TModule,TPreHook, TPostHook> module;
12671286
private string key;
12681287
}
12691288
}
@@ -1273,7 +1292,7 @@ public void remove()
12731292
/// </summary>
12741293
/// <typeparam name="T">The argument type of the module's forward() function.</typeparam>
12751294
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1276-
public abstract class Module<T, TResult> : HookableModule<Func<Module<T,TResult>, T, T>, Func<Module<T, TResult>, T, TResult, TResult>>, IModule<T, TResult>
1295+
public abstract class Module<T, TResult> : HookableModule<Module<T, TResult>, Func<Module<T,TResult>, T, T>, Func<Module<T, TResult>, T, TResult, TResult>>, IModule<T, TResult>
12771296
{
12781297
protected Module(string name) : base(name) { }
12791298
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1293,6 +1312,10 @@ public TResult call(T input)
12931312
{
12941313
// Call pre-hooks, if available.
12951314

1315+
foreach (var hook in module_pre_hooks.Values) {
1316+
hook(this);
1317+
}
1318+
12961319
foreach (var hook in pre_hooks.Values) {
12971320
var modified = hook(this, input);
12981321
if (modified is not null)
@@ -1309,6 +1332,10 @@ public TResult call(T input)
13091332
result = modified;
13101333
}
13111334

1335+
foreach (var hook in module_post_hooks.Values) {
1336+
hook(this);
1337+
}
1338+
13121339
return result;
13131340
}
13141341
}
@@ -1319,7 +1346,7 @@ public TResult call(T input)
13191346
/// <typeparam name="T1">The first argument type of the module's forward() function.</typeparam>
13201347
/// <typeparam name="T2">The second argument type of the module's forward() function.</typeparam>
13211348
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1322-
public abstract class Module<T1, T2, TResult> : HookableModule<Func<Module<T1, T2, TResult>, T1, T2, (T1, T2)?>, Func<Module<T1, T2, TResult>, T1, T2, TResult, TResult>>, IModule<T1, T2, TResult>
1349+
public abstract class Module<T1, T2, TResult> : HookableModule<Module<T1, T2, TResult>, Func<Module<T1, T2, TResult>, T1, T2, (T1, T2)?>, Func<Module<T1, T2, TResult>, T1, T2, TResult, TResult>>, IModule<T1, T2, TResult>
13231350
{
13241351
protected Module(string name) : base(name) { }
13251352
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1339,6 +1366,10 @@ public TResult call(T1 input1, T2 input2)
13391366
{
13401367
// Call pre-hooks, if available.
13411368

1369+
foreach (var hook in module_pre_hooks.Values) {
1370+
hook(this);
1371+
}
1372+
13421373
foreach (var hook in pre_hooks.Values) {
13431374
var modified = hook(this, input1, input2);
13441375
if (modified.HasValue) {
@@ -1357,6 +1388,10 @@ public TResult call(T1 input1, T2 input2)
13571388
result = modified;
13581389
}
13591390

1391+
foreach (var hook in module_post_hooks.Values) {
1392+
hook(this);
1393+
}
1394+
13601395
return result;
13611396
}
13621397
}
@@ -1368,7 +1403,7 @@ public TResult call(T1 input1, T2 input2)
13681403
/// <typeparam name="T2">The second argument type of the module's forward() function.</typeparam>
13691404
/// <typeparam name="T3">The third argument type of the module's forward() function.</typeparam>
13701405
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1371-
public abstract class Module<T1, T2, T3, TResult> : HookableModule<Func<Module<T1, T2, T3, TResult>, T1, T2, T3, (T1, T2, T3)?>, Func<Module<T1, T2, T3, TResult>, T1, T2, T3, TResult, TResult>>, IModule<T1, T2, T3, TResult>
1406+
public abstract class Module<T1, T2, T3, TResult> : HookableModule<Module<T1, T2, T3, TResult>,Func<Module<T1, T2, T3, TResult>, T1, T2, T3, (T1, T2, T3)?>, Func<Module<T1, T2, T3, TResult>, T1, T2, T3, TResult, TResult>>, IModule<T1, T2, T3, TResult>
13721407
{
13731408
protected Module(string name) : base(name) { }
13741409
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1388,6 +1423,10 @@ public TResult call(T1 input1, T2 input2, T3 input3)
13881423
{
13891424
// Call pre-hooks, if available.
13901425

1426+
foreach (var hook in module_pre_hooks.Values) {
1427+
hook(this);
1428+
}
1429+
13911430
foreach (var hook in pre_hooks.Values) {
13921431
var modified = hook(this, input1, input2, input3);
13931432
if (modified.HasValue) {
@@ -1407,6 +1446,10 @@ public TResult call(T1 input1, T2 input2, T3 input3)
14071446
result = modified;
14081447
}
14091448

1449+
foreach (var hook in module_post_hooks.Values) {
1450+
hook(this);
1451+
}
1452+
14101453
return result;
14111454
}
14121455
}
@@ -1419,7 +1462,7 @@ public TResult call(T1 input1, T2 input2, T3 input3)
14191462
/// <typeparam name="T3">The third argument type of the module's forward() function.</typeparam>
14201463
/// <typeparam name="T4">The fourth argument type of the module's forward() function.</typeparam>
14211464
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1422-
public abstract class Module<T1, T2, T3, T4, TResult> : HookableModule<Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, (T1, T2, T3, T4)?>, Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, TResult, TResult>>, IModule<T1, T2, T3, T4, TResult>
1465+
public abstract class Module<T1, T2, T3, T4, TResult> : HookableModule<Module<T1, T2, T3, T4, TResult>, Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, (T1, T2, T3, T4)?>, Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, TResult, TResult>>, IModule<T1, T2, T3, T4, TResult>
14231466
{
14241467
protected Module(string name) : base(name) { }
14251468
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1439,6 +1482,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4)
14391482
{
14401483
// Call pre-hooks, if available.
14411484

1485+
foreach (var hook in module_pre_hooks.Values) {
1486+
hook(this);
1487+
}
1488+
14421489
foreach (var hook in pre_hooks.Values) {
14431490
var modified = hook(this, input1, input2, input3, input4);
14441491
if (modified.HasValue) {
@@ -1459,6 +1506,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4)
14591506
result = modified;
14601507
}
14611508

1509+
foreach (var hook in module_post_hooks.Values) {
1510+
hook(this);
1511+
}
1512+
14621513
return result;
14631514
}
14641515
}
@@ -1472,7 +1523,7 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4)
14721523
/// <typeparam name="T4">The fourth argument type of the module's forward() function.</typeparam>
14731524
/// <typeparam name="T5">The fifth argument type of the module's forward() function.</typeparam>
14741525
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1475-
public abstract class Module<T1, T2, T3, T4, T5,TResult> : HookableModule<Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, (T1, T2, T3, T4, T5)?>, Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, TResult>
1526+
public abstract class Module<T1, T2, T3, T4, T5,TResult> : HookableModule<Module<T1, T2, T3, T4, T5,TResult>, Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, (T1, T2, T3, T4, T5)?>, Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, TResult>
14761527
{
14771528
protected Module(string name) : base(name) { }
14781529
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1492,6 +1543,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5)
14921543
{
14931544
// Call pre-hooks, if available.
14941545

1546+
foreach (var hook in module_pre_hooks.Values) {
1547+
hook(this);
1548+
}
1549+
14951550
foreach (var hook in pre_hooks.Values) {
14961551
var modified = hook(this, input1, input2, input3, input4, input5);
14971552
if (modified.HasValue) {
@@ -1513,6 +1568,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5)
15131568
result = modified;
15141569
}
15151570

1571+
foreach (var hook in module_post_hooks.Values) {
1572+
hook(this);
1573+
}
1574+
15161575
return result;
15171576
}
15181577
}
@@ -1527,7 +1586,7 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5)
15271586
/// <typeparam name="T5">The fifth argument type of the module's forward() function.</typeparam>
15281587
/// <typeparam name="T6">The sixth argument type of the module's forward() function.</typeparam>
15291588
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1530-
public abstract class Module<T1, T2, T3, T4, T5, T6, TResult> : HookableModule<Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, (T1, T2, T3, T4, T5, T6)?>, Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, T6, TResult>
1589+
public abstract class Module<T1, T2, T3, T4, T5, T6, TResult> : HookableModule<Module<T1, T2, T3, T4, T5, T6, TResult>, Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, (T1, T2, T3, T4, T5, T6)?>, Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, T6, TResult>
15311590
{
15321591
protected Module(string name) : base(name) { }
15331592
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1547,6 +1606,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5, T6 in
15471606
{
15481607
// Call pre-hooks, if available.
15491608

1609+
foreach (var hook in module_pre_hooks.Values) {
1610+
hook(this);
1611+
}
1612+
15501613
foreach (var hook in pre_hooks.Values) {
15511614
var modified = hook(this, input1, input2, input3, input4, input5, input6);
15521615
if (modified.HasValue) {
@@ -1569,6 +1632,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5, T6 in
15691632
result = modified;
15701633
}
15711634

1635+
foreach (var hook in module_post_hooks.Values) {
1636+
hook(this);
1637+
}
1638+
15721639
return result;
15731640
}
15741641
}

test/TorchSharpTest/NN.cs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6583,5 +6583,45 @@ public void TestModulePostHooks()
65836583
lin1.call(input);
65846584
Assert.Equal(1, counter);
65856585
}
6586+
6587+
[Fact]
6588+
public void TestModulePreHooksGeneric()
6589+
{
6590+
var lin1 = torch.nn.Linear(100, 10);
6591+
var input = torch.randn(32, 100, 100);
6592+
var counter = 0;
6593+
6594+
var pre_hook = (Module<Tensor, Tensor> m) => { counter += 1;};
6595+
6596+
var handle = lin1.register_forward_pre_hook(pre_hook);
6597+
6598+
lin1.call(input);
6599+
Assert.Equal(1, counter);
6600+
6601+
handle.remove();
6602+
6603+
lin1.call(input);
6604+
Assert.Equal(1, counter);
6605+
}
6606+
6607+
[Fact]
6608+
public void TestModulePostHooksGeneric()
6609+
{
6610+
var lin1 = torch.nn.Linear(100, 10);
6611+
var input = torch.randn(32, 100, 100);
6612+
var counter = 0;
6613+
6614+
var hook = (Module<Tensor, Tensor> m) => { counter += 1;};
6615+
6616+
var handle = lin1.register_forward_hook(hook);
6617+
6618+
lin1.call(input);
6619+
Assert.Equal(1, counter);
6620+
6621+
handle.remove();
6622+
6623+
lin1.call(input);
6624+
Assert.Equal(1, counter);
6625+
}
65866626
}
65876627
}

0 commit comments

Comments
 (0)