@@ -1216,7 +1216,7 @@ public interface IModule<T1, T2, T3, T4, T5, T6, TResult>
1216
1216
/// <summary>
1217
1217
/// Represents a module that accepts 'hook' to the module logic.
1218
1218
/// </summary>
1219
- public class HookableModule < TPreHook , TPostHook > : Module
1219
+ public class HookableModule < TModule , TPreHook , TPostHook > : Module
1220
1220
{
1221
1221
protected HookableModule ( string name ) : base ( name ) { }
1222
1222
@@ -1237,22 +1237,41 @@ public HookRemover register_forward_pre_hook(TPreHook hook)
1237
1237
return new HookRemover ( this , key ) ;
1238
1238
}
1239
1239
1240
+ public HookRemover register_forward_hook ( Action < TModule > hook )
1241
+ {
1242
+ var key = Guid . NewGuid ( ) . ToString ( ) ;
1243
+ module_post_hooks . Add ( key , hook ) ;
1244
+ return new HookRemover ( this , key ) ;
1245
+ }
1246
+
1247
+ public HookRemover register_forward_pre_hook ( Action < TModule > hook )
1248
+ {
1249
+ var key = Guid . NewGuid ( ) . ToString ( ) ;
1250
+ module_pre_hooks . Add ( key , hook ) ;
1251
+ return new HookRemover ( this , key ) ;
1252
+ }
1253
+
1240
1254
private void remove ( string key )
1241
1255
{
1242
1256
if ( pre_hooks . ContainsKey ( key ) ) pre_hooks . Remove ( key ) ;
1243
1257
if ( post_hooks . ContainsKey ( key ) ) post_hooks . Remove ( key ) ;
1258
+ if ( module_pre_hooks . ContainsKey ( key ) ) module_pre_hooks . Remove ( key ) ;
1259
+ if ( module_post_hooks . ContainsKey ( key ) ) module_post_hooks . Remove ( key ) ;
1244
1260
}
1245
1261
1246
1262
protected Dictionary < string , TPreHook > pre_hooks = new Dictionary < string , TPreHook > ( ) ;
1247
1263
protected Dictionary < string , TPostHook > post_hooks = new Dictionary < string , TPostHook > ( ) ;
1248
1264
1265
+ protected Dictionary < string , Action < TModule > > module_pre_hooks = new Dictionary < string , Action < TModule > > ( ) ;
1266
+ protected Dictionary < string , Action < TModule > > module_post_hooks = new Dictionary < string , Action < TModule > > ( ) ;
1267
+
1249
1268
/// <summary>
1250
1269
/// Used to remove a specific hook, following the PyTorch API design.
1251
1270
/// </summary>
1252
1271
/// <remarks>The name and namespace of this class is not the same as in PyTorch, but serves the same purpose.</remarks>
1253
1272
public class HookRemover
1254
1273
{
1255
- public HookRemover ( HookableModule < TPreHook , TPostHook > module , string key )
1274
+ public HookRemover ( HookableModule < TModule , TPreHook , TPostHook > module , string key )
1256
1275
{
1257
1276
this . module = module ;
1258
1277
this . key = key ;
@@ -1263,7 +1282,7 @@ public void remove()
1263
1282
module . remove ( key ) ;
1264
1283
}
1265
1284
1266
- private HookableModule < TPreHook , TPostHook > module ;
1285
+ private HookableModule < TModule , TPreHook , TPostHook > module ;
1267
1286
private string key ;
1268
1287
}
1269
1288
}
@@ -1273,7 +1292,7 @@ public void remove()
1273
1292
/// </summary>
1274
1293
/// <typeparam name="T">The argument type of the module's forward() function.</typeparam>
1275
1294
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1276
- public abstract class Module < T , TResult > : HookableModule < Func < Module < T , TResult > , T , T > , Func < Module < T , TResult > , T , TResult , TResult > > , IModule < T , TResult >
1295
+ public abstract class Module < T , TResult > : HookableModule < Module < T , TResult > , Func < Module < T , TResult > , T , T > , Func < Module < T , TResult > , T , TResult , TResult > > , IModule < T , TResult >
1277
1296
{
1278
1297
protected Module ( string name ) : base ( name ) { }
1279
1298
protected Module ( IntPtr handle , IntPtr boxedHandle ) : base ( handle , boxedHandle ) { }
@@ -1293,6 +1312,10 @@ public TResult call(T input)
1293
1312
{
1294
1313
// Call pre-hooks, if available.
1295
1314
1315
+ foreach ( var hook in module_pre_hooks . Values ) {
1316
+ hook ( this ) ;
1317
+ }
1318
+
1296
1319
foreach ( var hook in pre_hooks . Values ) {
1297
1320
var modified = hook ( this , input ) ;
1298
1321
if ( modified is not null )
@@ -1309,6 +1332,10 @@ public TResult call(T input)
1309
1332
result = modified ;
1310
1333
}
1311
1334
1335
+ foreach ( var hook in module_post_hooks . Values ) {
1336
+ hook ( this ) ;
1337
+ }
1338
+
1312
1339
return result ;
1313
1340
}
1314
1341
}
@@ -1319,7 +1346,7 @@ public TResult call(T input)
1319
1346
/// <typeparam name="T1">The first argument type of the module's forward() function.</typeparam>
1320
1347
/// <typeparam name="T2">The second argument type of the module's forward() function.</typeparam>
1321
1348
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1322
- public abstract class Module < T1 , T2 , TResult > : HookableModule < Func < Module < T1 , T2 , TResult > , T1 , T2 , ( T1 , T2 ) ? > , Func < Module < T1 , T2 , TResult > , T1 , T2 , TResult , TResult > > , IModule < T1 , T2 , TResult >
1349
+ public abstract class Module < T1 , T2 , TResult > : HookableModule < Module < T1 , T2 , TResult > , Func < Module < T1 , T2 , TResult > , T1 , T2 , ( T1 , T2 ) ? > , Func < Module < T1 , T2 , TResult > , T1 , T2 , TResult , TResult > > , IModule < T1 , T2 , TResult >
1323
1350
{
1324
1351
protected Module ( string name ) : base ( name ) { }
1325
1352
protected Module ( IntPtr handle , IntPtr boxedHandle ) : base ( handle , boxedHandle ) { }
@@ -1339,6 +1366,10 @@ public TResult call(T1 input1, T2 input2)
1339
1366
{
1340
1367
// Call pre-hooks, if available.
1341
1368
1369
+ foreach ( var hook in module_pre_hooks . Values ) {
1370
+ hook ( this ) ;
1371
+ }
1372
+
1342
1373
foreach ( var hook in pre_hooks . Values ) {
1343
1374
var modified = hook ( this , input1 , input2 ) ;
1344
1375
if ( modified . HasValue ) {
@@ -1357,6 +1388,10 @@ public TResult call(T1 input1, T2 input2)
1357
1388
result = modified ;
1358
1389
}
1359
1390
1391
+ foreach ( var hook in module_post_hooks . Values ) {
1392
+ hook ( this ) ;
1393
+ }
1394
+
1360
1395
return result ;
1361
1396
}
1362
1397
}
@@ -1368,7 +1403,7 @@ public TResult call(T1 input1, T2 input2)
1368
1403
/// <typeparam name="T2">The second argument type of the module's forward() function.</typeparam>
1369
1404
/// <typeparam name="T3">The third argument type of the module's forward() function.</typeparam>
1370
1405
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1371
- public abstract class Module < T1 , T2 , T3 , TResult > : HookableModule < Func < Module < T1 , T2 , T3 , TResult > , T1 , T2 , T3 , ( T1 , T2 , T3 ) ? > , Func < Module < T1 , T2 , T3 , TResult > , T1 , T2 , T3 , TResult , TResult > > , IModule < T1 , T2 , T3 , TResult >
1406
+ public abstract class Module < T1 , T2 , T3 , TResult > : HookableModule < Module < T1 , T2 , T3 , TResult > , Func < Module < T1 , T2 , T3 , TResult > , T1 , T2 , T3 , ( T1 , T2 , T3 ) ? > , Func < Module < T1 , T2 , T3 , TResult > , T1 , T2 , T3 , TResult , TResult > > , IModule < T1 , T2 , T3 , TResult >
1372
1407
{
1373
1408
protected Module ( string name ) : base ( name ) { }
1374
1409
protected Module ( IntPtr handle , IntPtr boxedHandle ) : base ( handle , boxedHandle ) { }
@@ -1388,6 +1423,10 @@ public TResult call(T1 input1, T2 input2, T3 input3)
1388
1423
{
1389
1424
// Call pre-hooks, if available.
1390
1425
1426
+ foreach ( var hook in module_pre_hooks . Values ) {
1427
+ hook ( this ) ;
1428
+ }
1429
+
1391
1430
foreach ( var hook in pre_hooks . Values ) {
1392
1431
var modified = hook ( this , input1 , input2 , input3 ) ;
1393
1432
if ( modified . HasValue ) {
@@ -1407,6 +1446,10 @@ public TResult call(T1 input1, T2 input2, T3 input3)
1407
1446
result = modified ;
1408
1447
}
1409
1448
1449
+ foreach ( var hook in module_post_hooks . Values ) {
1450
+ hook ( this ) ;
1451
+ }
1452
+
1410
1453
return result ;
1411
1454
}
1412
1455
}
@@ -1419,7 +1462,7 @@ public TResult call(T1 input1, T2 input2, T3 input3)
1419
1462
/// <typeparam name="T3">The third argument type of the module's forward() function.</typeparam>
1420
1463
/// <typeparam name="T4">The fourth argument type of the module's forward() function.</typeparam>
1421
1464
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1422
- public abstract class Module < T1 , T2 , T3 , T4 , TResult > : HookableModule < Func < Module < T1 , T2 , T3 , T4 , TResult > , T1 , T2 , T3 , T4 , ( T1 , T2 , T3 , T4 ) ? > , Func < Module < T1 , T2 , T3 , T4 , TResult > , T1 , T2 , T3 , T4 , TResult , TResult > > , IModule < T1 , T2 , T3 , T4 , TResult >
1465
+ public abstract class Module < T1 , T2 , T3 , T4 , TResult > : HookableModule < Module < T1 , T2 , T3 , T4 , TResult > , Func < Module < T1 , T2 , T3 , T4 , TResult > , T1 , T2 , T3 , T4 , ( T1 , T2 , T3 , T4 ) ? > , Func < Module < T1 , T2 , T3 , T4 , TResult > , T1 , T2 , T3 , T4 , TResult , TResult > > , IModule < T1 , T2 , T3 , T4 , TResult >
1423
1466
{
1424
1467
protected Module ( string name ) : base ( name ) { }
1425
1468
protected Module ( IntPtr handle , IntPtr boxedHandle ) : base ( handle , boxedHandle ) { }
@@ -1439,6 +1482,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4)
1439
1482
{
1440
1483
// Call pre-hooks, if available.
1441
1484
1485
+ foreach ( var hook in module_pre_hooks . Values ) {
1486
+ hook ( this ) ;
1487
+ }
1488
+
1442
1489
foreach ( var hook in pre_hooks . Values ) {
1443
1490
var modified = hook ( this , input1 , input2 , input3 , input4 ) ;
1444
1491
if ( modified . HasValue ) {
@@ -1459,6 +1506,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4)
1459
1506
result = modified ;
1460
1507
}
1461
1508
1509
+ foreach ( var hook in module_post_hooks . Values ) {
1510
+ hook ( this ) ;
1511
+ }
1512
+
1462
1513
return result ;
1463
1514
}
1464
1515
}
@@ -1472,7 +1523,7 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4)
1472
1523
/// <typeparam name="T4">The fourth argument type of the module's forward() function.</typeparam>
1473
1524
/// <typeparam name="T5">The fifth argument type of the module's forward() function.</typeparam>
1474
1525
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1475
- public abstract class Module < T1 , T2 , T3 , T4 , T5 , TResult > : HookableModule < Func < Module < T1 , T2 , T3 , T4 , T5 , TResult > , T1 , T2 , T3 , T4 , T5 , ( T1 , T2 , T3 , T4 , T5 ) ? > , Func < Module < T1 , T2 , T3 , T4 , T5 , TResult > , T1 , T2 , T3 , T4 , T5 , TResult , TResult > > , IModule < T1 , T2 , T3 , T4 , T5 , TResult >
1526
+ public abstract class Module < T1 , T2 , T3 , T4 , T5 , TResult > : HookableModule < Module < T1 , T2 , T3 , T4 , T5 , TResult > , Func < Module < T1 , T2 , T3 , T4 , T5 , TResult > , T1 , T2 , T3 , T4 , T5 , ( T1 , T2 , T3 , T4 , T5 ) ? > , Func < Module < T1 , T2 , T3 , T4 , T5 , TResult > , T1 , T2 , T3 , T4 , T5 , TResult , TResult > > , IModule < T1 , T2 , T3 , T4 , T5 , TResult >
1476
1527
{
1477
1528
protected Module ( string name ) : base ( name ) { }
1478
1529
protected Module ( IntPtr handle , IntPtr boxedHandle ) : base ( handle , boxedHandle ) { }
@@ -1492,6 +1543,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5)
1492
1543
{
1493
1544
// Call pre-hooks, if available.
1494
1545
1546
+ foreach ( var hook in module_pre_hooks . Values ) {
1547
+ hook ( this ) ;
1548
+ }
1549
+
1495
1550
foreach ( var hook in pre_hooks . Values ) {
1496
1551
var modified = hook ( this , input1 , input2 , input3 , input4 , input5 ) ;
1497
1552
if ( modified . HasValue ) {
@@ -1513,6 +1568,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5)
1513
1568
result = modified ;
1514
1569
}
1515
1570
1571
+ foreach ( var hook in module_post_hooks . Values ) {
1572
+ hook ( this ) ;
1573
+ }
1574
+
1516
1575
return result ;
1517
1576
}
1518
1577
}
@@ -1527,7 +1586,7 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5)
1527
1586
/// <typeparam name="T5">The fifth argument type of the module's forward() function.</typeparam>
1528
1587
/// <typeparam name="T6">The sixth argument type of the module's forward() function.</typeparam>
1529
1588
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1530
- public abstract class Module < T1 , T2 , T3 , T4 , T5 , T6 , TResult > : HookableModule < Func < Module < T1 , T2 , T3 , T4 , T5 , T6 , TResult > , T1 , T2 , T3 , T4 , T5 , T6 , ( T1 , T2 , T3 , T4 , T5 , T6 ) ? > , Func < Module < T1 , T2 , T3 , T4 , T5 , T6 , TResult > , T1 , T2 , T3 , T4 , T5 , T6 , TResult , TResult > > , IModule < T1 , T2 , T3 , T4 , T5 , T6 , TResult >
1589
+ public abstract class Module < T1 , T2 , T3 , T4 , T5 , T6 , TResult > : HookableModule < Module < T1 , T2 , T3 , T4 , T5 , T6 , TResult > , Func < Module < T1 , T2 , T3 , T4 , T5 , T6 , TResult > , T1 , T2 , T3 , T4 , T5 , T6 , ( T1 , T2 , T3 , T4 , T5 , T6 ) ? > , Func < Module < T1 , T2 , T3 , T4 , T5 , T6 , TResult > , T1 , T2 , T3 , T4 , T5 , T6 , TResult , TResult > > , IModule < T1 , T2 , T3 , T4 , T5 , T6 , TResult >
1531
1590
{
1532
1591
protected Module ( string name ) : base ( name ) { }
1533
1592
protected Module ( IntPtr handle , IntPtr boxedHandle ) : base ( handle , boxedHandle ) { }
@@ -1547,6 +1606,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5, T6 in
1547
1606
{
1548
1607
// Call pre-hooks, if available.
1549
1608
1609
+ foreach ( var hook in module_pre_hooks . Values ) {
1610
+ hook ( this ) ;
1611
+ }
1612
+
1550
1613
foreach ( var hook in pre_hooks . Values ) {
1551
1614
var modified = hook ( this , input1 , input2 , input3 , input4 , input5 , input6 ) ;
1552
1615
if ( modified . HasValue ) {
@@ -1569,6 +1632,10 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5, T6 in
1569
1632
result = modified ;
1570
1633
}
1571
1634
1635
+ foreach ( var hook in module_post_hooks . Values ) {
1636
+ hook ( this ) ;
1637
+ }
1638
+
1572
1639
return result ;
1573
1640
}
1574
1641
}
0 commit comments