@@ -8,84 +8,27 @@ namespace TorchSharp
8
8
{
9
9
using Modules ;
10
10
11
- public enum PaddingModes
12
- {
13
- Zeros = 0 ,
14
- Reflect = 1 ,
15
- Replicate = 2 ,
16
- Circular = 3 ,
17
- Constant = 4 ,
18
- }
19
-
20
- public enum Padding
21
- {
22
- Valid = 0 ,
23
- Same = 1
24
- }
25
-
26
11
namespace Modules
27
12
{
28
- public abstract class Convolution : torch . nn . Module < Tensor , Tensor >
29
- {
30
- protected Convolution ( IntPtr handle , IntPtr boxedHandle , long input_channels ) : base ( handle , boxedHandle )
31
- {
32
- this . input_channels = input_channels ;
33
- }
34
-
35
- protected bool ValidateShape ( Tensor input , long dimensions )
36
- {
37
- var shape = input . shape ;
38
- var ndim = shape . LongLength ;
39
-
40
- return ( ndim == dimensions + 2 ) && ( input . shape [ 1 ] == input_channels ) || // Batched: N + C + dims
41
- ( ndim == dimensions + 1 && input . shape [ 0 ] == input_channels ) ; // Unbathced: C + dims
42
-
43
- }
44
-
45
- protected long input_channels ;
46
- }
47
-
48
13
public sealed class Conv1d : Convolution
49
14
{
50
- internal Conv1d ( IntPtr handle , IntPtr boxedHandle , long input_channels ) : base ( handle , boxedHandle , input_channels ) { }
15
+ internal Conv1d ( long in_channels , long out_channels , long kernel_size , long stride , long ? padding , Padding ? padding_type , long dilation , long groups = 1 , bool bias = true , PaddingModes padding_mode = PaddingModes . Zeros , torch . Device ? device = null , ScalarType ? dtype = null )
16
+ : base ( nameof ( Conv1d ) , in_channels , out_channels , new [ ] { kernel_size } , new [ ] { stride } , padding . HasValue ? new [ ] { padding . Value } : null , padding_type , new [ ] { dilation } , false , new [ ] { 0L } , groups , bias , padding_mode , device , dtype ) { }
51
17
52
18
public override Tensor forward ( Tensor input )
53
19
{
54
- if ( ValidateShape ( input , 1 ) ) {
55
- var res = THSNN_Conv1d_forward ( handle , input . Handle ) ;
56
- if ( res == IntPtr . Zero ) { torch . CheckForErrors ( ) ; }
57
- return new Tensor ( res ) ;
58
- }
59
- throw new ArgumentException ( $ "Expected 2D (unbatched) or 3D (batched) input with { input_channels } channels to Conv1d.") ;
60
- }
20
+ if ( ! ValidateShape ( input , 1 ) )
21
+ throw new ArgumentException ( $ "Expected 2D (unbatched) or 3D (batched) input with { in_channels } channels to Conv1d.") ;
61
22
62
- public Parameter ? bias {
63
- get {
64
- var res = THSNN_Conv1d_bias ( handle ) ;
65
- if ( res == IntPtr . Zero ) { torch . CheckForErrors ( ) ; }
66
- return ( ( res == IntPtr . Zero ) ? null : new Parameter ( res ) ) ;
67
- }
68
- set {
69
- // Please ignore, for now, that the litorch call thinks you *can* set it to null.
70
- if ( value is null ) throw new ArgumentNullException ( "bias cannot be set to 'null'" ) ;
71
- THSNN_Conv1d_set_bias ( handle , ( value is null ? IntPtr . Zero : value . Handle ) ) ;
72
- torch . CheckForErrors ( ) ;
73
- ConditionallyRegisterParameter ( "bias" , value ) ;
74
- }
75
- }
76
- public Parameter ? weight {
77
- get {
78
- var res = THSNN_Conv1d_weight ( handle ) ;
79
- if ( res == IntPtr . Zero ) { torch . CheckForErrors ( ) ; }
80
- return ( res == IntPtr . Zero ) ? null : new Parameter ( res ) ;
81
- }
82
- set {
83
- // Please ignore, for now, that the litorch call thinks you *can* set it to null.
84
- if ( value is null ) throw new ArgumentNullException ( "weight cannot be set to 'null'" ) ;
85
- THSNN_Conv1d_set_weight ( handle , value is null ? IntPtr . Zero : value . Handle ) ;
86
- torch . CheckForErrors ( ) ;
87
- ConditionallyRegisterParameter ( "weight" , value ) ;
23
+ if ( padding_mode != PaddingModes . Zeros ) {
24
+ using var paddedInput = torch . nn . functional . pad ( input , _reversed_padding_repeated_twice , padding_mode ) ;
25
+ return torch . nn . functional . conv1d ( paddedInput , weight , bias , stride [ 0 ] , 0 , dilation [ 0 ] , groups ) ;
88
26
}
27
+
28
+ if ( padding_type . HasValue )
29
+ return torch . nn . functional . conv1d_padding ( input , weight , bias , stride [ 0 ] , padding_type . Value , dilation [ 0 ] , groups ) ;
30
+
31
+ return torch . nn . functional . conv1d ( input , weight , bias , stride [ 0 ] , padding ? [ 0 ] , dilation [ 0 ] , groups ) ;
89
32
}
90
33
}
91
34
}
@@ -111,9 +54,7 @@ public static partial class nn
111
54
/// <returns>Tensor of shape (N,C_out,L_out)</returns>
112
55
public static Conv1d Conv1d ( long in_channels , long out_channels , long kernel_size , long stride = 1 , long padding = 0 , long dilation = 1 , PaddingModes padding_mode = PaddingModes . Zeros , long groups = 1 , bool bias = true , Device ? device = null , ScalarType ? dtype = null )
113
56
{
114
- var res = THSNN_Conv1d_ctor ( in_channels , out_channels , kernel_size , stride , padding , dilation , ( long ) padding_mode , groups , bias , out var boxedHandle ) ;
115
- if ( res == IntPtr . Zero ) { torch . CheckForErrors ( ) ; }
116
- return new Conv1d ( res , boxedHandle , in_channels ) . MoveModule < Conv1d > ( device , dtype ) ;
57
+ return new Conv1d ( in_channels , out_channels , kernel_size , stride , padding , null , dilation , groups , bias , padding_mode , device , dtype ) ;
117
58
}
118
59
119
60
/// <summary>
@@ -133,9 +74,7 @@ public static Conv1d Conv1d(long in_channels, long out_channels, long kernel_siz
133
74
/// <returns>Tensor of shape (N,C_out,L_out)</returns>
134
75
public static Conv1d Conv1d ( long in_channels , long out_channels , long kernel_size , Padding padding , long stride = 1 , long dilation = 1 , PaddingModes padding_mode = PaddingModes . Zeros , long groups = 1 , bool bias = true , Device ? device = null , ScalarType ? dtype = null )
135
76
{
136
- var res = THSNN_Conv1d_ctor ( in_channels , out_channels , kernel_size , stride , padding == Padding . Valid ? 0 : - 1 , dilation , ( long ) padding_mode , groups , bias , out var boxedHandle ) ;
137
- if ( res == IntPtr . Zero ) { torch . CheckForErrors ( ) ; }
138
- return new Conv1d ( res , boxedHandle , in_channels ) . MoveModule < Conv1d > ( device , dtype ) ;
77
+ return new Conv1d ( in_channels , out_channels , kernel_size , stride , null , padding , dilation , groups , bias , padding_mode , device , dtype ) ;
139
78
}
140
79
141
80
public static partial class functional
@@ -144,12 +83,12 @@ public static partial class functional
144
83
/// Applies a 1D convolution over an input signal composed of several input planes.
145
84
/// </summary>
146
85
/// <param name="input">The input tensor.</param>
147
- /// <param name="weight"></param>
148
- /// <param name="bias"></param>
149
- /// <param name="stride"></param>
150
- /// <param name="padding"></param>
151
- /// <param name="dilation"></param>
152
- /// <param name="groups"></param>
86
+ /// <param name="weight">weight matrix of the convolution </param>
87
+ /// <param name="bias">Optional; bias vector of the convolution </param>
88
+ /// <param name="stride">Stride of the convolution. Default: (1,) </param>
89
+ /// <param name="padding">Zero-padding added to both sides of the input. Default: (0,) </param>
90
+ /// <param name="dilation">Spacing between kernel elements. Default: (1,) </param>
91
+ /// <param name="groups">Number of blocked connections from input channels to output channels. Default: 1 </param>
153
92
/// <returns></returns>
154
93
public static Tensor conv1d ( Tensor input , Tensor weight , Tensor ? bias = null ,
155
94
long ? stride = null ,
@@ -175,6 +114,39 @@ public static Tensor conv1d(Tensor input, Tensor weight, Tensor? bias = null,
175
114
}
176
115
}
177
116
117
+ /// <summary>
118
+ /// Applies a 1D convolution over an input signal composed of several input planes.
119
+ /// </summary>
120
+ /// <param name="input">The input tensor.</param>
121
+ /// <param name="weight">weight matrix of the convolution</param>
122
+ /// <param name="bias">Optional; bias vector of the convolution</param>
123
+ /// <param name="stride">Stride of the convolution. Default: (1,)</param>
124
+ /// <param name="padding">Zero-padding added to both sides of the input. padding=Valid is the same as no padding. padding=Same pads the input so the output has the shape as the input. </param>
125
+ /// <param name="dilation">Spacing between kernel elements. Default: (1,)</param>
126
+ /// <param name="groups">Number of blocked connections from input channels to output channels. Default: 1</param>
127
+ /// <returns></returns>
128
+ public static Tensor conv1d_padding ( Tensor input , Tensor weight , Tensor ? bias = null ,
129
+ long ? stride = null ,
130
+ Padding padding = Padding . Valid ,
131
+ long ? dilation = null ,
132
+ long groups = 1 )
133
+ {
134
+ var strides = new long [ ] { stride ?? 1 } ;
135
+ var dilationArray = new long [ ] { dilation ?? 1 } ;
136
+ var biasHandle = ( bias is null ? IntPtr . Zero : bias . Handle ) ;
137
+ unsafe {
138
+ fixed ( long * pstrides = strides , pdilation = dilationArray ) {
139
+ var res =
140
+ THSTensor_conv1d_padding ( input . Handle , weight . Handle , biasHandle ,
141
+ ( IntPtr ) pstrides , strides . Length ,
142
+ ( int ) padding ,
143
+ ( IntPtr ) pdilation , dilationArray . Length ,
144
+ groups ) ;
145
+ if ( res == IntPtr . Zero ) { torch . CheckForErrors ( ) ; }
146
+ return new Tensor ( res ) ;
147
+ }
148
+ }
149
+ }
178
150
}
179
151
}
180
152
}
0 commit comments