@@ -9,11 +9,14 @@ namespace TorchSharp
9
9
10
10
namespace Modules
11
11
{
12
+ public interface IParameterLessModule {
13
+
14
+ }
12
15
/// <summary>
13
16
/// Base class for all modules that do not have any tensor parameters or buffers, and
14
17
/// for which the `_to()` implementation can therefore be simplified.
15
18
/// </summary>
16
- public abstract class ParamLessModule < T1 , T2 > : nn . Module < T1 , T2 >
19
+ public abstract class ParamLessModule < T1 , T2 > : nn . Module < T1 , T2 > , IParameterLessModule
17
20
{
18
21
protected ParamLessModule ( string name ) : base ( name ) { }
19
22
@@ -26,13 +29,30 @@ protected ParamLessModule(IntPtr handle, IntPtr boxedHandle) : base(handle, boxe
26
29
protected internal override nn . Module _to ( DeviceType deviceType , int deviceIndex = - 1 ) => this ;
27
30
28
31
protected internal override nn . Module _to ( ScalarType dtype ) => this ;
32
+
33
+ public override void register_buffer ( string name , Tensor tensor , bool persistent = true )
34
+ {
35
+ throw new InvalidOperationException ( $ "Cannot register a buffer on a module that is declared 'parameter-less.'") ;
36
+ }
37
+
38
+ public override void register_parameter ( string name , Parameter param )
39
+ {
40
+ throw new InvalidOperationException ( $ "Cannot register a parameter on a module that is declared 'parameter-less.'") ;
41
+ }
42
+
43
+ public override void register_module ( string name , nn . Module submodule )
44
+ {
45
+ if ( submodule is not IParameterLessModule )
46
+ throw new InvalidOperationException ( $ "Submodules of a parameter-less module must also be parameter-less.") ;
47
+ base . register_module ( name , submodule ) ;
48
+ }
29
49
}
30
50
31
51
/// <summary>
32
52
/// Base class for all modules that do not have any tensor parameters or buffers, and
33
53
/// for which the `_to()` implementation can therefore be simplified.
34
54
/// </summary>
35
- public abstract class ParamLessModule < T1 , T2 , T3 > : nn . Module < T1 , T2 , T3 >
55
+ public abstract class ParamLessModule < T1 , T2 , T3 > : nn . Module < T1 , T2 , T3 > , IParameterLessModule
36
56
{
37
57
protected ParamLessModule ( string name ) : base ( name ) { }
38
58
@@ -45,13 +65,30 @@ protected ParamLessModule(IntPtr handle, IntPtr boxedHandle) : base(handle, boxe
45
65
protected internal override nn . Module _to ( DeviceType deviceType , int deviceIndex = - 1 ) => this ;
46
66
47
67
protected internal override nn . Module _to ( ScalarType dtype ) => this ;
68
+
69
+ public override void register_buffer ( string name , Tensor tensor , bool persistent = true )
70
+ {
71
+ throw new InvalidOperationException ( $ "Cannot register a buffer on a module that is declared 'parameter-less.'") ;
72
+ }
73
+
74
+ public override void register_parameter ( string name , Parameter param )
75
+ {
76
+ throw new InvalidOperationException ( $ "Cannot register a parameter on a module that is declared 'parameter-less.'") ;
77
+ }
78
+
79
+ public override void register_module ( string name , nn . Module submodule )
80
+ {
81
+ if ( submodule is not IParameterLessModule )
82
+ throw new InvalidOperationException ( $ "Submodules of a parameter-less module must also be parameter-less.") ;
83
+ base . register_module ( name , submodule ) ;
84
+ }
48
85
}
49
86
50
87
/// <summary>
51
88
/// Base class for all modules that do not have any tensor parameters or buffers, and
52
89
/// for which the `_to()` implementation can therefore be simplified.
53
90
/// </summary>
54
- public abstract class ParamLessModule < T1 , T2 , T3 , T4 > : nn . Module < T1 , T2 , T3 , T4 >
91
+ public abstract class ParamLessModule < T1 , T2 , T3 , T4 > : nn . Module < T1 , T2 , T3 , T4 > , IParameterLessModule
55
92
{
56
93
protected ParamLessModule ( string name ) : base ( name ) { }
57
94
@@ -64,6 +101,23 @@ protected ParamLessModule(IntPtr handle, IntPtr boxedHandle) : base(handle, boxe
64
101
protected internal override nn . Module _to ( DeviceType deviceType , int deviceIndex = - 1 ) => this ;
65
102
66
103
protected internal override nn . Module _to ( ScalarType dtype ) => this ;
104
+
105
+ public override void register_buffer ( string name , Tensor tensor , bool persistent = true )
106
+ {
107
+ throw new InvalidOperationException ( $ "Cannot register a buffer on a module that is declared 'parameter-less.'") ;
108
+ }
109
+
110
+ public override void register_parameter ( string name , Parameter param )
111
+ {
112
+ throw new InvalidOperationException ( $ "Cannot register a parameter on a module that is declared 'parameter-less.'") ;
113
+ }
114
+
115
+ public override void register_module ( string name , nn . Module submodule )
116
+ {
117
+ if ( submodule is not IParameterLessModule )
118
+ throw new InvalidOperationException ( $ "Submodules of a parameter-less module must also be parameter-less.") ;
119
+ base . register_module ( name , submodule ) ;
120
+ }
67
121
}
68
122
}
69
123
}
0 commit comments