Skip to content

Commit 2f496aa

Browse files
Merge branch 'main' into unit
2 parents 1f47f44 + d35ee8d commit 2f496aa

File tree

3 files changed

+35
-15
lines changed

3 files changed

+35
-15
lines changed

src/TorchSharp/NN/Module.cs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,7 +1232,7 @@ public interface IModule<T1, T2, T3, T4, T5, T6, TResult>
12321232
/// <summary>
12331233
/// Represents a module that accepts 'hook' to the module logic.
12341234
/// </summary>
1235-
public class HookableModule<TModule,TPreHook,TPostHook> : Module
1235+
public class HookableModule<TPreHook,TPostHook> : Module
12361236
{
12371237
protected HookableModule(string name) : base(name) { }
12381238

@@ -1253,14 +1253,14 @@ public HookRemover register_forward_pre_hook(TPreHook hook)
12531253
return new HookRemover(this, key);
12541254
}
12551255

1256-
public HookRemover register_forward_hook(Action<TModule> hook)
1256+
public HookRemover register_forward_hook(Action<Module> hook)
12571257
{
12581258
var key = Guid.NewGuid().ToString();
12591259
module_post_hooks.Add(key, hook);
12601260
return new HookRemover(this, key);
12611261
}
12621262

1263-
public HookRemover register_forward_pre_hook(Action<TModule> hook)
1263+
public HookRemover register_forward_pre_hook(Action<Module> hook)
12641264
{
12651265
var key = Guid.NewGuid().ToString();
12661266
module_pre_hooks.Add(key, hook);
@@ -1278,16 +1278,16 @@ private void remove(string key)
12781278
protected Dictionary<string, TPreHook> pre_hooks = new Dictionary<string, TPreHook>();
12791279
protected Dictionary<string, TPostHook> post_hooks = new Dictionary<string, TPostHook>();
12801280

1281-
protected Dictionary<string, Action<TModule>> module_pre_hooks = new Dictionary<string, Action<TModule>>();
1282-
protected Dictionary<string, Action<TModule>> module_post_hooks = new Dictionary<string, Action<TModule>>();
1281+
protected Dictionary<string, Action<Module>> module_pre_hooks = new Dictionary<string, Action<Module>>();
1282+
protected Dictionary<string, Action<Module>> module_post_hooks = new Dictionary<string, Action<Module>>();
12831283

12841284
/// <summary>
12851285
/// Used to remove a specific hook, following the PyTorch API design.
12861286
/// </summary>
12871287
/// <remarks>The name and namespace of this class is not the same as in PyTorch, but serves the same purpose.</remarks>
12881288
public class HookRemover
12891289
{
1290-
public HookRemover(HookableModule<TModule,TPreHook, TPostHook> module, string key)
1290+
public HookRemover(HookableModule<TPreHook, TPostHook> module, string key)
12911291
{
12921292
this.module = module;
12931293
this.key = key;
@@ -1298,7 +1298,7 @@ public void remove()
12981298
module.remove(key);
12991299
}
13001300

1301-
private HookableModule<TModule,TPreHook, TPostHook> module;
1301+
private HookableModule<TPreHook, TPostHook> module;
13021302
private string key;
13031303
}
13041304
}
@@ -1308,7 +1308,7 @@ public void remove()
13081308
/// </summary>
13091309
/// <typeparam name="T">The argument type of the module's forward() function.</typeparam>
13101310
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1311-
public abstract class Module<T, TResult> : HookableModule<Module<T, TResult>, Func<Module<T,TResult>, T, T>, Func<Module<T, TResult>, T, TResult, TResult>>, IModule<T, TResult>
1311+
public abstract class Module<T, TResult> : HookableModule<Func<Module<T,TResult>, T, T>, Func<Module<T, TResult>, T, TResult, TResult>>, IModule<T, TResult>
13121312
{
13131313
protected Module(string name) : base(name) { }
13141314
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1362,7 +1362,7 @@ public TResult call(T input)
13621362
/// <typeparam name="T1">The first argument type of the module's forward() function.</typeparam>
13631363
/// <typeparam name="T2">The second argument type of the module's forward() function.</typeparam>
13641364
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1365-
public abstract class Module<T1, T2, TResult> : HookableModule<Module<T1, T2, TResult>, Func<Module<T1, T2, TResult>, T1, T2, (T1, T2)?>, Func<Module<T1, T2, TResult>, T1, T2, TResult, TResult>>, IModule<T1, T2, TResult>
1365+
public abstract class Module<T1, T2, TResult> : HookableModule<Func<Module<T1, T2, TResult>, T1, T2, (T1, T2)?>, Func<Module<T1, T2, TResult>, T1, T2, TResult, TResult>>, IModule<T1, T2, TResult>
13661366
{
13671367
protected Module(string name) : base(name) { }
13681368
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1419,7 +1419,7 @@ public TResult call(T1 input1, T2 input2)
14191419
/// <typeparam name="T2">The second argument type of the module's forward() function.</typeparam>
14201420
/// <typeparam name="T3">The third argument type of the module's forward() function.</typeparam>
14211421
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1422-
public abstract class Module<T1, T2, T3, TResult> : HookableModule<Module<T1, T2, T3, TResult>,Func<Module<T1, T2, T3, TResult>, T1, T2, T3, (T1, T2, T3)?>, Func<Module<T1, T2, T3, TResult>, T1, T2, T3, TResult, TResult>>, IModule<T1, T2, T3, TResult>
1422+
public abstract class Module<T1, T2, T3, TResult> : HookableModule<Func<Module<T1, T2, T3, TResult>, T1, T2, T3, (T1, T2, T3)?>, Func<Module<T1, T2, T3, TResult>, T1, T2, T3, TResult, TResult>>, IModule<T1, T2, T3, TResult>
14231423
{
14241424
protected Module(string name) : base(name) { }
14251425
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1478,7 +1478,7 @@ public TResult call(T1 input1, T2 input2, T3 input3)
14781478
/// <typeparam name="T3">The third argument type of the module's forward() function.</typeparam>
14791479
/// <typeparam name="T4">The fourth argument type of the module's forward() function.</typeparam>
14801480
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1481-
public abstract class Module<T1, T2, T3, T4, TResult> : HookableModule<Module<T1, T2, T3, T4, TResult>, Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, (T1, T2, T3, T4)?>, Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, TResult, TResult>>, IModule<T1, T2, T3, T4, TResult>
1481+
public abstract class Module<T1, T2, T3, T4, TResult> : HookableModule<Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, (T1, T2, T3, T4)?>, Func<Module<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, TResult, TResult>>, IModule<T1, T2, T3, T4, TResult>
14821482
{
14831483
protected Module(string name) : base(name) { }
14841484
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1539,7 +1539,7 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4)
15391539
/// <typeparam name="T4">The fourth argument type of the module's forward() function.</typeparam>
15401540
/// <typeparam name="T5">The fifth argument type of the module's forward() function.</typeparam>
15411541
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1542-
public abstract class Module<T1, T2, T3, T4, T5,TResult> : HookableModule<Module<T1, T2, T3, T4, T5,TResult>, Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, (T1, T2, T3, T4, T5)?>, Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, TResult>
1542+
public abstract class Module<T1, T2, T3, T4, T5,TResult> : HookableModule<Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, (T1, T2, T3, T4, T5)?>, Func<Module<T1, T2, T3, T4, T5, TResult>, T1, T2, T3, T4, T5, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, TResult>
15431543
{
15441544
protected Module(string name) : base(name) { }
15451545
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }
@@ -1602,7 +1602,7 @@ public TResult call(T1 input1, T2 input2, T3 input3, T4 input4, T5 input5)
16021602
/// <typeparam name="T5">The fifth argument type of the module's forward() function.</typeparam>
16031603
/// <typeparam name="T6">The sixth argument type of the module's forward() function.</typeparam>
16041604
/// <typeparam name="TResult">The return type of the module's forward() function.</typeparam>
1605-
public abstract class Module<T1, T2, T3, T4, T5, T6, TResult> : HookableModule<Module<T1, T2, T3, T4, T5, T6, TResult>, Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, (T1, T2, T3, T4, T5, T6)?>, Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, T6, TResult>
1605+
public abstract class Module<T1, T2, T3, T4, T5, T6, TResult> : HookableModule<Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, (T1, T2, T3, T4, T5, T6)?>, Func<Module<T1, T2, T3, T4, T5, T6, TResult>, T1, T2, T3, T4, T5, T6, TResult, TResult>>, IModule<T1, T2, T3, T4, T5, T6, TResult>
16061606
{
16071607
protected Module(string name) : base(name) { }
16081608
protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { }

src/TorchSharp/Utils/TensorAccessor.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,16 @@ public void CopyTo(T[] array, int arrayIndex = 0, long tensorIndex = 0)
241241
}
242242
}
243243

244+
public void CopyTo(Span<T> array, int arrayIndex = 0, long tensorIndex = 0)
245+
{
246+
int idx = arrayIndex;
247+
foreach (int offset in GetSubsequentIndices(tensorIndex)) {
248+
if (idx >= array.Length) break;
249+
unsafe { array[idx] = ((T*)_tensor_data_ptr)[offset]; }
250+
idx += 1;
251+
}
252+
}
253+
244254
public void CopyFrom(T[] array, int arrayIndex = 0, long tensorIndex = 0)
245255
{
246256
int idx = arrayIndex;
@@ -251,6 +261,16 @@ public void CopyFrom(T[] array, int arrayIndex = 0, long tensorIndex = 0)
251261
}
252262
}
253263

264+
public void CopyFrom(ReadOnlySpan<T> array, int arrayIndex = 0, long tensorIndex = 0)
265+
{
266+
int idx = arrayIndex;
267+
foreach (int offset in GetSubsequentIndices(tensorIndex)) {
268+
if (idx >= array.Length) break;
269+
unsafe { ((T*)_tensor_data_ptr)[offset] = array[idx]; }
270+
idx += 1;
271+
}
272+
}
273+
254274
/// <summary>
255275
/// Translates a linear index within the span represented by the accessor to a linear index
256276
/// used by the underlying tensor. The two should only be different if the tensor is a view

test/TorchSharpTest/NN.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6726,7 +6726,7 @@ public void TestModulePreHooksGeneric()
67266726
var input = torch.randn(32, 100, 100);
67276727
var counter = 0;
67286728

6729-
var pre_hook = (Module<Tensor, Tensor> m) => { counter += 1;};
6729+
var pre_hook = (Module m) => { counter += 1;};
67306730

67316731
var handle = lin1.register_forward_pre_hook(pre_hook);
67326732

@@ -6746,7 +6746,7 @@ public void TestModulePostHooksGeneric()
67466746
var input = torch.randn(32, 100, 100);
67476747
var counter = 0;
67486748

6749-
var hook = (Module<Tensor, Tensor> m) => { counter += 1;};
6749+
var hook = (Module m) => { counter += 1;};
67506750

67516751
var handle = lin1.register_forward_hook(hook);
67526752

0 commit comments

Comments
 (0)