Skip to content

Commit ccf7893

Browse files
Merge branch 'unit' of https://github.com/NiklasGustafsson/TorchSharp into unit
2 parents a90d679 + 58a9b5d commit ccf7893

File tree

15 files changed

+562
-483
lines changed

15 files changed

+562
-483
lines changed

src/Native/LibTorchSharp/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ set(SOURCES
2222
crc32c.c
2323
THSActivation.cpp
2424
THSAutograd.cpp
25-
THSConvolution.cpp
26-
THSData.cpp
25+
THSData.cpp
2726
THSFFT.cpp
2827
THSJIT.cpp
2928
THSLinearAlgebra.cpp

src/Native/LibTorchSharp/THSNN.h

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -37,50 +37,6 @@ EXPORT_API(void) THSNN_AnyModule_dispose(const NNAnyModule module);
3737

3838
EXPORT_API(NNModule) THSNN_custom_module(const char* name, Tensor(*forward)(Tensor), NNAnyModule* outAsAnyModule);
3939

40-
// Convolution
41-
42-
EXPORT_API(NNModule) THSNN_Conv1d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
43-
EXPORT_API(Tensor) THSNN_Conv1d_forward(const NNModule module, const Tensor tensor);
44-
EXPORT_API(Tensor) THSNN_Conv1d_bias(const NNModule module);
45-
EXPORT_API(void) THSNN_Conv1d_set_bias(const NNModule module, const Tensor bias);
46-
EXPORT_API(Tensor) THSNN_Conv1d_weight(const NNModule module);
47-
EXPORT_API(void) THSNN_Conv1d_set_weight(const NNModule module, const Tensor weight);
48-
EXPORT_API(NNModule) THSNN_Conv2d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
49-
EXPORT_API(NNModule) THSNN_Conv2d_ctor_1(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelX, const int64_t kernelY, const int64_t strideX, const int64_t strideY, const int64_t paddingX, const int64_t paddingY, const int64_t dilationX, const int64_t dilationY, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
50-
EXPORT_API(Tensor) THSNN_Conv2d_forward(const NNModule module, const Tensor tensor);
51-
EXPORT_API(Tensor) THSNN_Conv2d_weight(const NNModule module);
52-
EXPORT_API(void) THSNN_Conv2d_set_weight(const NNModule module, const Tensor weight);
53-
EXPORT_API(Tensor) THSNN_Conv2d_bias(const NNModule module);
54-
EXPORT_API(void) THSNN_Conv2d_set_bias(const NNModule module, const Tensor bias);
55-
EXPORT_API(NNModule) THSNN_Conv3d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
56-
EXPORT_API(NNModule) THSNN_Conv3d_ctor_1(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelX, const int64_t kernelY, const int64_t kernelZ, const int64_t strideX, const int64_t strideY, const int64_t strideZ, const int64_t paddingX, const int64_t paddingY, const int64_t paddingZ, const int64_t dilationX, const int64_t dilationY, const int64_t dilationZ, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
57-
EXPORT_API(Tensor) THSNN_Conv3d_forward(const NNModule module, const Tensor tensor);
58-
EXPORT_API(Tensor) THSNN_Conv3d_weight(const NNModule module);
59-
EXPORT_API(void) THSNN_Conv3d_set_weight(const NNModule module, const Tensor weight);
60-
EXPORT_API(Tensor) THSNN_Conv3d_bias(const NNModule module);
61-
EXPORT_API(void) THSNN_Conv3d_set_bias(const NNModule module, const Tensor bias);
62-
63-
EXPORT_API(NNModule) THSNN_ConvTranspose1d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t output_padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
64-
EXPORT_API(Tensor) THSNN_ConvTranspose1d_forward(const NNModule module, const Tensor tensor);
65-
EXPORT_API(Tensor) THSNN_ConvTranspose1d_bias(const NNModule module);
66-
EXPORT_API(void) THSNN_ConvTranspose1d_set_bias(const NNModule module, const Tensor bias);
67-
EXPORT_API(Tensor) THSNN_ConvTranspose1d_weight(const NNModule module);
68-
EXPORT_API(void) THSNN_ConvTranspose1d_set_weight(const NNModule module, const Tensor weight);
69-
EXPORT_API(NNModule) THSNN_ConvTranspose2d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t output_padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
70-
EXPORT_API(NNModule) THSNN_ConvTranspose2d_ctor_1(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelX, const int64_t kernelY, const int64_t strideX, const int64_t strideY, const int64_t paddingX, const int64_t paddingY, const int64_t output_paddingX, const int64_t output_paddingY, const int64_t dilationX, const int64_t dilationY, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
71-
EXPORT_API(Tensor) THSNN_ConvTranspose2d_forward(const NNModule module, const Tensor tensor);
72-
EXPORT_API(Tensor) THSNN_ConvTranspose2d_weight(const NNModule module);
73-
EXPORT_API(void) THSNN_ConvTranspose2d_set_weight(const NNModule module, const Tensor weight);
74-
EXPORT_API(Tensor) THSNN_ConvTranspose2d_bias(const NNModule module);
75-
EXPORT_API(void) THSNN_ConvTranspose2d_set_bias(const NNModule module, const Tensor bias);
76-
EXPORT_API(NNModule) THSNN_ConvTranspose3d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t output_padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
77-
EXPORT_API(NNModule) THSNN_ConvTranspose3d_ctor_1(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelX, const int64_t kernelY, const int64_t kernelZ, const int64_t strideX, const int64_t strideY, const int64_t strideZ, const int64_t paddingX, const int64_t paddingY, const int64_t paddingZ, const int64_t output_paddingX, const int64_t output_paddingY, const int64_t output_paddingZ, const int64_t dilationX, const int64_t dilationY, const int64_t dilationZ, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
78-
EXPORT_API(Tensor) THSNN_ConvTranspose3d_forward(const NNModule module, const Tensor tensor);
79-
EXPORT_API(Tensor) THSNN_ConvTranspose3d_weight(const NNModule module);
80-
EXPORT_API(void) THSNN_ConvTranspose3d_set_weight(const NNModule module, const Tensor weight);
81-
EXPORT_API(Tensor) THSNN_ConvTranspose3d_bias(const NNModule module);
82-
EXPORT_API(void) THSNN_ConvTranspose3d_set_bias(const NNModule module, const Tensor bias);
83-
8440
// Normalization
8541

8642
EXPORT_API(Tensor) THSNN_batch_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool training, const double momentum, const double eps);

src/Native/LibTorchSharp/THSTensor.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,24 @@ EXPORT_API(Tensor) THSTensor_conv3d(const Tensor input, const Tensor weight, con
314314
const int64_t* dilations, const int dilations_length,
315315
int64_t groups);
316316

317+
EXPORT_API(Tensor) THSTensor_conv1d_padding(const Tensor input, const Tensor weight, const Tensor bias,
318+
const int64_t* strides, const int strides_length,
319+
const int padding,
320+
const int64_t* dilations, const int dilations_length,
321+
int64_t groups);
322+
323+
EXPORT_API(Tensor) THSTensor_conv2d_padding(const Tensor input, const Tensor weight, const Tensor bias,
324+
const int64_t* strides, const int strides_length,
325+
const int padding,
326+
const int64_t* dilations, const int dilations_length,
327+
int64_t groups);
328+
329+
EXPORT_API(Tensor) THSTensor_conv3d_padding(const Tensor input, const Tensor weight, const Tensor bias,
330+
const int64_t* strides, const int strides_length,
331+
const int padding,
332+
const int64_t* dilations, const int dilations_length,
333+
int64_t groups);
334+
317335
EXPORT_API(Tensor) THSTensor_conv_transpose1d(const Tensor input, const Tensor weight, const Tensor bias,
318336
const int64_t* strides, const int strides_length,
319337
const int64_t* paddings, const int paddings_length,

src/Native/LibTorchSharp/THSTensorConv.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,63 @@ Tensor THSTensor_conv3d(
340340
groups));
341341
}
342342

343+
static c10::string_view get_padding_str(int padding) {
344+
if (padding == 0)
345+
return "valid";
346+
else if (padding == 1)
347+
return "same";
348+
349+
TORCH_CHECK(false, "Invalid padding string specified");
350+
}
351+
352+
Tensor THSTensor_conv1d_padding(
353+
const Tensor input,
354+
const Tensor weight,
355+
const Tensor bias,
356+
const int64_t* stride, const int strideLength,
357+
const int padding,
358+
const int64_t* dilation, const int dilationLength,
359+
int64_t groups)
360+
{
361+
CATCH_TENSOR(torch::conv1d(*input, *weight, (bias ? *bias : at::Tensor()),
362+
at::ArrayRef<int64_t>(stride, strideLength),
363+
get_padding_str(padding),
364+
at::ArrayRef<int64_t>(dilation, dilationLength),
365+
groups));
366+
}
367+
368+
369+
Tensor THSTensor_conv2d_padding(
370+
const Tensor input,
371+
const Tensor weight,
372+
const Tensor bias,
373+
const int64_t* stride, const int strideLength,
374+
const int padding,
375+
const int64_t* dilation, const int dilationLength,
376+
int64_t groups)
377+
{
378+
CATCH_TENSOR(torch::conv2d(*input, *weight, (bias ? *bias : at::Tensor()),
379+
at::ArrayRef<int64_t>(stride, strideLength),
380+
get_padding_str(padding),
381+
at::ArrayRef<int64_t>(dilation, dilationLength),
382+
groups));
383+
}
384+
385+
Tensor THSTensor_conv3d_padding(
386+
const Tensor input,
387+
const Tensor weight,
388+
const Tensor bias,
389+
const int64_t* stride, const int strideLength,
390+
const int padding,
391+
const int64_t* dilation, const int dilationLength,
392+
int64_t groups)
393+
{
394+
CATCH_TENSOR(torch::conv3d(*input, *weight, (bias ? *bias : at::Tensor()),
395+
at::ArrayRef<int64_t>(stride, strideLength),
396+
get_padding_str(padding),
397+
at::ArrayRef<int64_t>(dilation, dilationLength),
398+
groups));
399+
}
343400

344401
Tensor THSTensor_max_pool1d(
345402
const Tensor tensor,

src/TorchSharp/NN/Convolution/Conv1D.cs

Lines changed: 53 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -8,84 +8,27 @@ namespace TorchSharp
88
{
99
using Modules;
1010

11-
public enum PaddingModes
12-
{
13-
Zeros = 0,
14-
Reflect = 1,
15-
Replicate = 2,
16-
Circular = 3,
17-
Constant = 4,
18-
}
19-
20-
public enum Padding
21-
{
22-
Valid = 0,
23-
Same = 1
24-
}
25-
2611
namespace Modules
2712
{
28-
public abstract class Convolution : torch.nn.Module<Tensor, Tensor>
29-
{
30-
protected Convolution(IntPtr handle, IntPtr boxedHandle, long input_channels) : base(handle, boxedHandle)
31-
{
32-
this.input_channels = input_channels;
33-
}
34-
35-
protected bool ValidateShape(Tensor input, long dimensions)
36-
{
37-
var shape = input.shape;
38-
var ndim = shape.LongLength;
39-
40-
return (ndim == dimensions+2) && (input.shape[1] == input_channels) || // Batched: N + C + dims
41-
(ndim == dimensions+1 && input.shape[0] == input_channels); // Unbathced: C + dims
42-
43-
}
44-
45-
protected long input_channels;
46-
}
47-
4813
public sealed class Conv1d : Convolution
4914
{
50-
internal Conv1d(IntPtr handle, IntPtr boxedHandle, long input_channels) : base(handle, boxedHandle, input_channels) { }
15+
internal Conv1d(long in_channels, long out_channels, long kernel_size, long stride, long? padding, Padding? padding_type, long dilation, long groups = 1, bool bias = true, PaddingModes padding_mode = PaddingModes.Zeros, torch.Device? device = null, ScalarType? dtype = null)
16+
: base(nameof(Conv1d), in_channels, out_channels, new[] { kernel_size }, new[] { stride }, padding.HasValue ? new[] { padding.Value } : null, padding_type, new[] { dilation }, false, new[] { 0L }, groups, bias, padding_mode, device, dtype) { }
5117

5218
public override Tensor forward(Tensor input)
5319
{
54-
if (ValidateShape(input, 1)) {
55-
var res = THSNN_Conv1d_forward(handle, input.Handle);
56-
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
57-
return new Tensor(res);
58-
}
59-
throw new ArgumentException($"Expected 2D (unbatched) or 3D (batched) input with {input_channels} channels to Conv1d.");
60-
}
20+
if (!ValidateShape(input, 1))
21+
throw new ArgumentException($"Expected 2D (unbatched) or 3D (batched) input with {in_channels} channels to Conv1d.");
6122

62-
public Parameter? bias {
63-
get {
64-
var res = THSNN_Conv1d_bias(handle);
65-
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
66-
return ((res == IntPtr.Zero) ? null : new Parameter(res));
67-
}
68-
set {
69-
// Please ignore, for now, that the litorch call thinks you *can* set it to null.
70-
if (value is null) throw new ArgumentNullException("bias cannot be set to 'null'");
71-
THSNN_Conv1d_set_bias(handle, (value is null ? IntPtr.Zero : value.Handle));
72-
torch.CheckForErrors();
73-
ConditionallyRegisterParameter("bias", value);
74-
}
75-
}
76-
public Parameter? weight {
77-
get {
78-
var res = THSNN_Conv1d_weight(handle);
79-
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
80-
return (res == IntPtr.Zero) ? null : new Parameter(res);
81-
}
82-
set {
83-
// Please ignore, for now, that the litorch call thinks you *can* set it to null.
84-
if (value is null) throw new ArgumentNullException("weight cannot be set to 'null'");
85-
THSNN_Conv1d_set_weight(handle, value is null ? IntPtr.Zero : value.Handle);
86-
torch.CheckForErrors();
87-
ConditionallyRegisterParameter("weight", value);
23+
if (padding_mode != PaddingModes.Zeros) {
24+
using var paddedInput = torch.nn.functional.pad(input, _reversed_padding_repeated_twice, padding_mode);
25+
return torch.nn.functional.conv1d(paddedInput, weight, bias, stride[0], 0, dilation[0], groups);
8826
}
27+
28+
if (padding_type.HasValue)
29+
return torch.nn.functional.conv1d_padding(input, weight, bias, stride[0], padding_type.Value, dilation[0], groups);
30+
31+
return torch.nn.functional.conv1d(input, weight, bias, stride[0], padding?[0], dilation[0], groups);
8932
}
9033
}
9134
}
@@ -111,9 +54,7 @@ public static partial class nn
11154
/// <returns>Tensor of shape (N,C_out,L_out)</returns>
11255
public static Conv1d Conv1d(long in_channels, long out_channels, long kernel_size, long stride = 1, long padding = 0, long dilation = 1, PaddingModes padding_mode = PaddingModes.Zeros, long groups = 1, bool bias = true, Device? device = null, ScalarType? dtype = null)
11356
{
114-
var res = THSNN_Conv1d_ctor(in_channels, out_channels, kernel_size, stride, padding, dilation, (long)padding_mode, groups, bias, out var boxedHandle);
115-
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
116-
return new Conv1d(res, boxedHandle, in_channels).MoveModule<Conv1d>(device, dtype);
57+
return new Conv1d(in_channels, out_channels, kernel_size, stride, padding, null, dilation, groups, bias, padding_mode, device, dtype);
11758
}
11859

11960
/// <summary>
@@ -133,9 +74,7 @@ public static Conv1d Conv1d(long in_channels, long out_channels, long kernel_siz
13374
/// <returns>Tensor of shape (N,C_out,L_out)</returns>
13475
public static Conv1d Conv1d(long in_channels, long out_channels, long kernel_size, Padding padding, long stride = 1, long dilation = 1, PaddingModes padding_mode = PaddingModes.Zeros, long groups = 1, bool bias = true, Device? device = null, ScalarType? dtype = null)
13576
{
136-
var res = THSNN_Conv1d_ctor(in_channels, out_channels, kernel_size, stride, padding == Padding.Valid ? 0 : -1, dilation, (long)padding_mode, groups, bias, out var boxedHandle);
137-
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
138-
return new Conv1d(res, boxedHandle, in_channels).MoveModule<Conv1d>(device, dtype);
77+
return new Conv1d(in_channels, out_channels, kernel_size, stride, null, padding, dilation, groups, bias, padding_mode, device, dtype);
13978
}
14079

14180
public static partial class functional
@@ -144,12 +83,12 @@ public static partial class functional
14483
/// Applies a 1D convolution over an input signal composed of several input planes.
14584
/// </summary>
14685
/// <param name="input">The input tensor.</param>
147-
/// <param name="weight"></param>
148-
/// <param name="bias"></param>
149-
/// <param name="stride"></param>
150-
/// <param name="padding"></param>
151-
/// <param name="dilation"></param>
152-
/// <param name="groups"></param>
86+
/// <param name="weight">weight matrix of the convolution</param>
87+
/// <param name="bias">Optional; bias vector of the convolution</param>
88+
/// <param name="stride">Stride of the convolution. Default: (1,)</param>
89+
/// <param name="padding">Zero-padding added to both sides of the input. Default: (0,)</param>
90+
/// <param name="dilation">Spacing between kernel elements. Default: (1,)</param>
91+
/// <param name="groups">Number of blocked connections from input channels to output channels. Default: 1</param>
15392
/// <returns></returns>
15493
public static Tensor conv1d(Tensor input, Tensor weight, Tensor? bias = null,
15594
long? stride = null,
@@ -175,6 +114,39 @@ public static Tensor conv1d(Tensor input, Tensor weight, Tensor? bias = null,
175114
}
176115
}
177116

117+
/// <summary>
118+
/// Applies a 1D convolution over an input signal composed of several input planes.
119+
/// </summary>
120+
/// <param name="input">The input tensor.</param>
121+
/// <param name="weight">weight matrix of the convolution</param>
122+
/// <param name="bias">Optional; bias vector of the convolution</param>
123+
/// <param name="stride">Stride of the convolution. Default: (1,)</param>
124+
/// <param name="padding">Zero-padding added to both sides of the input. padding=Valid is the same as no padding. padding=Same pads the input so the output has the shape as the input. </param>
125+
/// <param name="dilation">Spacing between kernel elements. Default: (1,)</param>
126+
/// <param name="groups">Number of blocked connections from input channels to output channels. Default: 1</param>
127+
/// <returns></returns>
128+
public static Tensor conv1d_padding(Tensor input, Tensor weight, Tensor? bias = null,
129+
long? stride = null,
130+
Padding padding = Padding.Valid,
131+
long? dilation = null,
132+
long groups = 1)
133+
{
134+
var strides = new long[] { stride ?? 1 };
135+
var dilationArray = new long[] { dilation ?? 1 };
136+
var biasHandle = (bias is null ? IntPtr.Zero : bias.Handle);
137+
unsafe {
138+
fixed (long* pstrides = strides, pdilation = dilationArray) {
139+
var res =
140+
THSTensor_conv1d_padding(input.Handle, weight.Handle, biasHandle,
141+
(IntPtr)pstrides, strides.Length,
142+
(int)padding,
143+
(IntPtr)pdilation, dilationArray.Length,
144+
groups);
145+
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
146+
return new Tensor(res);
147+
}
148+
}
149+
}
178150
}
179151
}
180152
}

0 commit comments

Comments
 (0)