Skip to content

Commit 254470e

Browse files
Overrides on ParamLessModule
1 parent 25e128c commit 254470e

File tree

2 files changed

+89
-3
lines changed

2 files changed

+89
-3
lines changed

src/TorchSharp/NN/ParamLessModule.cs

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@ namespace TorchSharp
99

1010
namespace Modules
1111
{
12+
public interface IParameterLessModule {
13+
14+
}
1215
/// <summary>
1316
/// Base class for all modules that do not have any tensor parameters or buffers, and
1417
/// for which the `_to()` implementation can therefore be simplified.
1518
/// </summary>
16-
public abstract class ParamLessModule<T1, T2> : nn.Module<T1, T2>
19+
public abstract class ParamLessModule<T1, T2> : nn.Module<T1, T2>, IParameterLessModule
1720
{
1821
protected ParamLessModule(string name) : base(name) { }
1922

@@ -26,13 +29,30 @@ protected ParamLessModule(IntPtr handle, IntPtr boxedHandle) : base(handle, boxe
2629
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
2730

2831
protected internal override nn.Module _to(ScalarType dtype) => this;
32+
33+
public override void register_buffer(string name, Tensor tensor, bool persistent = true)
34+
{
35+
throw new InvalidOperationException($"Cannot register a buffer on a module that is declared 'parameter-less.'");
36+
}
37+
38+
public override void register_parameter(string name, Parameter param)
39+
{
40+
throw new InvalidOperationException($"Cannot register a parameter on a module that is declared 'parameter-less.'");
41+
}
42+
43+
public override void register_module(string name, nn.Module submodule)
44+
{
45+
if (submodule is not IParameterLessModule)
46+
throw new InvalidOperationException($"Submodules of a parameter-less module must also be parameter-less.");
47+
base.register_module(name, submodule);
48+
}
2949
}
3050

3151
/// <summary>
3252
/// Base class for all modules that do not have any tensor parameters or buffers, and
3353
/// for which the `_to()` implementation can therefore be simplified.
3454
/// </summary>
35-
public abstract class ParamLessModule<T1, T2, T3> : nn.Module<T1, T2, T3>
55+
public abstract class ParamLessModule<T1, T2, T3> : nn.Module<T1, T2, T3>, IParameterLessModule
3656
{
3757
protected ParamLessModule(string name) : base(name) { }
3858

@@ -45,13 +65,30 @@ protected ParamLessModule(IntPtr handle, IntPtr boxedHandle) : base(handle, boxe
4565
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
4666

4767
protected internal override nn.Module _to(ScalarType dtype) => this;
68+
69+
public override void register_buffer(string name, Tensor tensor, bool persistent = true)
70+
{
71+
throw new InvalidOperationException($"Cannot register a buffer on a module that is declared 'parameter-less.'");
72+
}
73+
74+
public override void register_parameter(string name, Parameter param)
75+
{
76+
throw new InvalidOperationException($"Cannot register a parameter on a module that is declared 'parameter-less.'");
77+
}
78+
79+
public override void register_module(string name, nn.Module submodule)
80+
{
81+
if (submodule is not IParameterLessModule)
82+
throw new InvalidOperationException($"Submodules of a parameter-less module must also be parameter-less.");
83+
base.register_module(name, submodule);
84+
}
4885
}
4986

5087
/// <summary>
5188
/// Base class for all modules that do not have any tensor parameters or buffers, and
5289
/// for which the `_to()` implementation can therefore be simplified.
5390
/// </summary>
54-
public abstract class ParamLessModule<T1, T2, T3, T4> : nn.Module<T1, T2, T3, T4>
91+
public abstract class ParamLessModule<T1, T2, T3, T4> : nn.Module<T1, T2, T3, T4>, IParameterLessModule
5592
{
5693
protected ParamLessModule(string name) : base(name) { }
5794

@@ -64,6 +101,23 @@ protected ParamLessModule(IntPtr handle, IntPtr boxedHandle) : base(handle, boxe
64101
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
65102

66103
protected internal override nn.Module _to(ScalarType dtype) => this;
104+
105+
public override void register_buffer(string name, Tensor tensor, bool persistent = true)
106+
{
107+
throw new InvalidOperationException($"Cannot register a buffer on a module that is declared 'parameter-less.'");
108+
}
109+
110+
public override void register_parameter(string name, Parameter param)
111+
{
112+
throw new InvalidOperationException($"Cannot register a parameter on a module that is declared 'parameter-less.'");
113+
}
114+
115+
public override void register_module(string name, nn.Module submodule)
116+
{
117+
if (submodule is not IParameterLessModule)
118+
throw new InvalidOperationException($"Submodules of a parameter-less module must also be parameter-less.");
119+
base.register_module(name, submodule);
120+
}
67121
}
68122
}
69123
}

test/TorchSharpTest/NN.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6633,5 +6633,37 @@ public void TestModulePostHooks()
66336633
lin1.call(input);
66346634
Assert.Equal(1, counter);
66356635
}
6636+
6637+
[Fact]
6638+
public void TestCustomParameterLessModule()
6639+
{
6640+
var cnp = new CustomNoParameters("test");
6641+
6642+
// Should not throw
6643+
cnp.register_module("sub", new CustomNoParameters("test"));
6644+
6645+
Assert.True(cnp.named_modules().Count() > 0);
6646+
Assert.Equal("sub", cnp.named_modules().First().name);
6647+
6648+
Assert.Throws<InvalidOperationException>(() => cnp.register_module("test", torch.nn.Linear(10,10, true)));
6649+
Assert.Throws<InvalidOperationException>(() => cnp.register_buffer("test", torch.rand(10)));
6650+
Assert.Throws<InvalidOperationException>(() => cnp.register_parameter("test", new Parameter(torch.rand(10))));
6651+
}
6652+
6653+
class CustomNoParameters : ParamLessModule<Tensor, Tensor>
6654+
{
6655+
public CustomNoParameters(string name) : base(name)
6656+
{
6657+
}
6658+
6659+
public CustomNoParameters(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle)
6660+
{
6661+
}
6662+
6663+
public override Tensor forward(Tensor input)
6664+
{
6665+
throw new NotImplementedException();
6666+
}
6667+
}
66366668
}
66376669
}

0 commit comments

Comments
 (0)