Skip to content

Commit 59d0d37

Browse files
Module hooks don't depend on the forward() signature.
Add CopyTo/From overloads to TensorAccessor for Span<T>.
1 parent 6778374 commit 59d0d37

File tree

3 files changed

+35
-15
lines changed

3 files changed

+35
-15
lines changed

src/TorchSharp/NN/Module.cs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,7 +1230,7 @@ public interface IModule<T1, T2, T3, T4, T5, T6, TResult>
12301230
/// <summary>
12311231
/// Represents a module that accepts 'hook' to the module logic.
12321232
/// </summary>
1233-
public class HookableModule<TModule,TPreHook,TPostHook> : Module
1233+
public class HookableModule<TPreHook,TPostHook> : Module
12341234
{
12351235
protected HookableModule(string name) : base(name) { }
12361236

@@ -1251,14 +1251,14 @@ public HookRemover register_forward_pre_hook(TPreHook hook)
12511251
return new HookRemover(this, key);
12521252
}
12531253

1254-
public HookRemover register_forward_hook(Action<TModule> hook)
1254+
public HookRemover register_forward_hook(Action<Module> hook)
12551255
{
12561256
var key = Guid.NewGuid().ToString();
12571257
module_post_hooks.Add(key, hook);
12581258
return new HookRemover(this, key);
12591259
}
12601260

1261-
public HookRemover register_forward_pre_hook(Action<TModule> hook)
1261+
public HookRemover register_forward_pre_hook(Action<Module> hook)
12621262
{
12631263
var key = Guid.NewGuid().ToString();
12641264
module_pre_hooks.Add(key, hook);
@@ -1276,16 +1276,16 @@ private void remove(string key)
12761276
protected Dictionary<string, TPreHook> pre_hooks = new Dictionary<string, TPreHook>();
12771277
protected Dictionary<string, TPostHook> post_hooks = new Dictionary<string, TPostHook>();
12781278

1279-
protected Dictionary<string, Action<TModule>> module_pre_hooks = new Dictionary<string, Action<TModule>>();
1280-
protected Dictionary<string, Action<TModule>> module_post_hooks = new Dictionary<string, Action<TModule>>();
1279+
protected Dictionary<string, Action<Module>> module_pre_hooks = new Dictionary<string, Action<Module>>();
1280+
protected Dictionary<string, Action<Module>> module_post_hooks = new Dictionary<string, Action<Module>>();
12811281

12821282
/// <summary>
12831283
/// Used to remove a specific hook, following the PyTorch API design.
12841284
/// </summary>
12851285
/// <remarks>The name and namespace of this class is not the same as in PyTorch, but serves the same purpose.</remarks>
12861286
public class HookRemover
12871287
{
1288-
public HookRemover(HookableModule<TModule,TPreHook, TPostHook> module, string key)
1288+
public HookRemover(HookableModule<TPreHook, TPostHook> module, string key)
12891289
{
12901290
this.module = module;
12911291
this.key = key;
@@ -1296,7 +1296,7 @@ public void remove()
12961296
module.remove(key);
12971297
}
12981298

1299-
private HookableModule<TModule,TPreHook, TPostHook> module;
1299+
private HookableModule<TPreHook, TPostHook> module;
13001300
private string key;
13011301
}
13021302
}
@@ -1306,7 +1306,7 @@ public void remove()
13061306
/// </summary>
13071307
/// <typeparam name="T">The argument type of the module's forward() function.</typeparam>
13081308
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1309-
public abstract class Module<T, TResult> : HookableModule<Module<T, TResult>, Func<Module<T,TResult>, T, T>, Func<Module<T, TResult>, T, TResult, TResult>>, IModule<T, TResult>
1309+
public abstract class Module<T, TResult> : HookableModule<Func<Module<T,TResult>, T, T>, Func<Module<T, TResult>, T, TResult, TResult>>, IModule<T, TResult>
13101310
{
13111311
protected Module(string name) : base(name) { }
13121312
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1360,7 +1360,7 @@ public TResult call(T input)
13601360
/// <typeparam name="T1">The first argument type of the module's forward() function.</typeparam>
13611361
/// <typeparam name="T2">The second argument type of the module's forward() function.</typeparam>
13621362
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1363-
public abstract class Module<T1, T2, TResult> : HookableModule<Module<T1, T2, TResult>, Func<Module<T1, T2, TResult>, T1, T2, (T1, T2)?>, Func<Module<T1, T2, TResult>, T1, T2, TResult, TResult>>, IModule<T1, T2, TResult>
1363+
public abstract class Module<T1, T2, TResult> : HookableModule<Func<Module<T1, T2, TResult>, T1, T2, (T1, T2)?>, Func<Module<T1, T2, TResult>, T1, T2, TResult, TResult>>, IModule<T1, T2, TResult>
13641364
{
13651365
protected Module(string name) : base(name) { }
13661366
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1417,7 +1417,7 @@ public TResult call(T1 input1, T2 input2)
14171417
/// <typeparam name="T2">The second argument type of the module's forward() function.</typeparam>
14181418
/// <typeparam name="T3">The third argument type of the module's forward() function.</typeparam>
14191419
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1420-
public abstract class Module<T1, T2, T3, TResult> : HookableModule<Module<T1, T2, T3, TResult>,Func<Module<T1, T2, T3, TResult>, T1, T2, T3, (T1, T2, T3)?>, Func<Module<T1, T2, T3, TResult>, T1, T2, T3, TResult, TResult>>, IModule<T1, T2, T3, TResult>
1420+
public abstract class Module<T1, T2, T3, TResult> : HookableModule<Func<Module<T1, T2, T3, TResult>, T1, T2, T3, (T1, T2, T3)?>, Func<Module<T1, T2, T3, TResult>, T1, T2, T3, TResult, TResult>>, IModule<T1, T2, T3, TResult>
14211421
{
14221422
protected Module(string name) : base(name) { }
14231423
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1476,7 +1476,7 @@ public TResult call(T1 input1, T2 input2, T3 input3)
14761476
/// <typeparam name="T3">The third argument type of the module's forward() function.</typeparam>
14771477
/// <typeparam name="T4">The fourth argument type of the module's forward() function.</typeparam>
14781478
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1479-
public abstract class Module<T1, T2, T3, T4, TResult> : HookableModule<Module<T1, T2, T3, T4, TResult>, Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, (T1, T2, T3, T4)?>, Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, TResult, TResult>>, IModule<T1, T2, T3, T4, TResult>
1479+
public abstract class Module<T1, T2, T3, T4, TResult> : HookableModule<Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, (T1, T2, T3, T4)?>, Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, TResult, TResult>>, IModule<T1, T2, T3, T4, TResult>
14801480
{
14811481
protected Module(string name) : base(name) { }
14821482
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1537,7 +1537,7 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4)
15371537
/// <typeparam name="T4">The fourth argument type of the module's forward() function.</typeparam>
15381538
/// <typeparam name="T5">The fifth argument type of the module's forward() function.</typeparam>
15391539
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1540-
public abstract class Module<T1, T2, T3, T4, T5,TResult> : HookableModule<Module<T1, T2, T3, T4, T5,TResult>, Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, (T1, T2, T3, T4, T5)?>, Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, TResult>
1540+
public abstract class Module<T1, T2, T3, T4, T5,TResult> : HookableModule<Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, (T1, T2, T3, T4, T5)?>, Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, TResult>
15411541
{
15421542
protected Module(string name) : base(name) { }
15431543
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1600,7 +1600,7 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5)
16001600
/// <typeparam name="T5">The fifth argument type of the module's forward() function.</typeparam>
16011601
/// <typeparam name="T6">The sixth argument type of the module's forward() function.</typeparam>
16021602
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1603-
public abstract class Module<T1, T2, T3, T4, T5, T6, TResult> : HookableModule<Module<T1, T2, T3, T4, T5, T6, TResult>, Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, (T1, T2, T3, T4, T5, T6)?>, Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, T6, TResult>
1603+
public abstract class Module<T1, T2, T3, T4, T5, T6, TResult> : HookableModule<Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, (T1, T2, T3, T4, T5, T6)?>, Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, T6, TResult>
16041604
{
16051605
protected Module(string name) : base(name) { }
16061606
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }

src/TorchSharp/Utils/TensorAccessor.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,16 @@ public void CopyTo(T[] array, int arrayIndex = 0, long tensorIndex = 0)
241241
}
242242
}
243243

244+
public void CopyTo(Span<T> array, int arrayIndex = 0, long tensorIndex = 0)
245+
{
246+
int idx = arrayIndex;
247+
foreach (int offset in GetSubsequentIndices(tensorIndex)) {
248+
if (idx >= array.Length) break;
249+
unsafe { array[idx] = ((T*)_tensor_data_ptr)[offset]; }
250+
idx += 1;
251+
}
252+
}
253+
244254
public void CopyFrom(T[] array, int arrayIndex = 0, long tensorIndex = 0)
245255
{
246256
int idx = arrayIndex;
@@ -251,6 +261,16 @@ public void CopyFrom(T[] array, int arrayIndex = 0, long tensorIndex = 0)
251261
}
252262
}
253263

264+
public void CopyFrom(ReadOnlySpan<T> array, int arrayIndex = 0, long tensorIndex = 0)
265+
{
266+
int idx = arrayIndex;
267+
foreach (int offset in GetSubsequentIndices(tensorIndex)) {
268+
if (idx >= array.Length) break;
269+
unsafe { ((T*)_tensor_data_ptr)[offset] = array[idx]; }
270+
idx += 1;
271+
}
272+
}
273+
254274
/// <summary>
255275
/// Translates a linear index within the span represented by the accessor to a linear index
256276
/// used by the underlying tensor. The two should only be different if the tensor is a view

test/TorchSharpTest/NN.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6591,7 +6591,7 @@ public void TestModulePreHooksGeneric()
65916591
var input = torch.randn(32, 100, 100);
65926592
var counter = 0;
65936593

6594-
var pre_hook = (Module<Tensor, Tensor> m) => { counter += 1;};
6594+
var pre_hook = (Module m) => { counter += 1;};
65956595

65966596
var handle = lin1.register_forward_pre_hook(pre_hook);
65976597

@@ -6611,7 +6611,7 @@ public void TestModulePostHooksGeneric()
66116611
var input = torch.randn(32, 100, 100);
66126612
var counter = 0;
66136613

6614-
var hook = (Module<Tensor, Tensor> m) => { counter += 1;};
6614+
var hook = (Module m) => { counter += 1;};
66156615

66166616
var handle = lin1.register_forward_hook(hook);
66176617

0 commit comments

Comments
 (0)