Skip to content

Commit 91f9bc2

Browse files
committed
Added modules for MNIST model
1 parent 01f0b7b commit 91f9bc2

File tree

9 files changed

+228
-30
lines changed

9 files changed

+228
-30
lines changed

TorchSharp/NN/Conv2D.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System;
2+
using System.Runtime.InteropServices;
3+
using TorchSharp.Tensor;
4+
5+
namespace TorchSharp.NN
6+
{
7+
public class Conv2D : Module
8+
{
9+
internal Conv2D(IntPtr handle) : base(handle)
10+
{
11+
}
12+
13+
[DllImport("LibTorchSharp")]
14+
extern static FloatTensor.HType NN_conv2DModule_Forward(Module.HType module, IntPtr tensor);
15+
16+
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
17+
{
18+
return new FloatTensor(NN_conv2DModule_Forward(handle, tensor.Handle));
19+
}
20+
}
21+
}

TorchSharp/NN/Dropout.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using System;
2+
using System.Runtime.InteropServices;
3+
using TorchSharp.Tensor;
4+
5+
namespace TorchSharp.NN
6+
{
7+
/// <summary>
8+
/// This class is used to represent a dropout module.
9+
/// </summary>
10+
public class Dropout : FunctionalModule
11+
{
12+
private double _probability;
13+
private Func<bool> _isTraining;
14+
15+
internal Dropout(double probability, Func<bool> isTraining) : base()
16+
{
17+
_probability = probability;
18+
_isTraining = isTraining;
19+
}
20+
21+
[DllImport("LibTorchSharp")]
22+
extern static FloatTensor.HType NN_LogSoftMaxModule_Forward(IntPtr tensor, double probability, bool isTraining);
23+
24+
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
25+
{
26+
return new FloatTensor(NN_LogSoftMaxModule_Forward(tensor.Handle, _probability, _isTraining.Invoke()));
27+
}
28+
}
29+
}

TorchSharp/NN/Functional.cs

Lines changed: 0 additions & 26 deletions
This file was deleted.

TorchSharp/NN/FunctionalModule.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Runtime.InteropServices;
4+
using TorchSharp.Tensor;
5+
6+
namespace TorchSharp.NN
7+
{
8+
/// <summary>
9+
/// This class is used to represent a functional module (e.g., ReLU).
10+
/// </summary>
11+
public abstract class FunctionalModule : Module
12+
{
13+
internal FunctionalModule() : base(IntPtr.Zero)
14+
{
15+
}
16+
17+
public override void ZeroGrad()
18+
{
19+
}
20+
21+
public override bool IsTraining()
22+
{
23+
return true;
24+
}
25+
26+
public override IEnumerable<ITorchTensor<float>> Parameters()
27+
{
28+
return new List<ITorchTensor<float>>();
29+
}
30+
31+
public override string[] GetModules()
32+
{
33+
return new string[0];
34+
}
35+
}
36+
}

TorchSharp/NN/LogSoftMax.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System;
2+
using System.Runtime.InteropServices;
3+
using TorchSharp.Tensor;
4+
5+
namespace TorchSharp.NN
6+
{
7+
/// <summary>
8+
/// This class is used to represent a log softmax module.
9+
/// </summary>
10+
public class LogSoftMax : FunctionalModule
11+
{
12+
private long _dimension;
13+
14+
internal LogSoftMax(long dimension) : base()
15+
{
16+
_dimension = dimension;
17+
}
18+
19+
[DllImport("LibTorchSharp")]
20+
extern static FloatTensor.HType NN_LogSoftMaxModule_Forward(IntPtr tensor, long dimension);
21+
22+
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
23+
{
24+
return new FloatTensor(NN_LogSoftMaxModule_Forward(tensor.Handle, _dimension));
25+
}
26+
}
27+
}

TorchSharp/NN/MaxPool2D.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System;
2+
using System.Runtime.InteropServices;
3+
using TorchSharp.Tensor;
4+
5+
namespace TorchSharp.NN
6+
{
7+
/// <summary>
8+
/// This class is used to represent a ReLu module.
9+
/// </summary>
10+
public class MaxPool2D : FunctionalModule
11+
{
12+
private long _kernelSize;
13+
14+
internal MaxPool2D(long kernelSize) : base()
15+
{
16+
_kernelSize = kernelSize;
17+
}
18+
19+
[DllImport("LibTorchSharp")]
20+
extern static FloatTensor.HType NN_MaxPool2DModule_Forward(IntPtr tensor, long kernelSize);
21+
22+
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
23+
{
24+
return new FloatTensor(NN_MaxPool2DModule_Forward(tensor.Handle, _kernelSize));
25+
}
26+
}
27+
}

TorchSharp/NN/Module.cs

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ protected void Dispose(bool disposing)
7373
}
7474
}
7575

76-
public abstract partial class Module : IDisposable
76+
public partial class Module : IDisposable
7777
{
7878
static public Sequential Sequential(params Module[] modules)
7979
{
@@ -89,15 +89,62 @@ static public Module Linear(int input, int output, bool hasBias = false)
8989
}
9090

9191
[DllImport("LibTorchSharp")]
92-
extern static IntPtr NN_reluModule();
92+
extern static IntPtr NN_conv2dModule(long inputChannel, long outputChannel, int kernelSize);
93+
94+
static public Module Conv2D(long inputChannel, long outputChannel, int kernelSize)
95+
{
96+
return new Conv2D(NN_conv2dModule(inputChannel, outputChannel, kernelSize));
97+
}
9398

9499
static public Module Relu()
95100
{
96-
return new Functional(NN_reluModule());
101+
return new ReLu();
102+
}
103+
104+
static public ITorchTensor<float> Relu(ITorchTensor<float> x)
105+
{
106+
return new ReLu().Forward(x);
107+
}
108+
109+
static public Module MaxPool2D(long kernelSize)
110+
{
111+
return new MaxPool2D(kernelSize);
112+
}
113+
114+
static public ITorchTensor<float> MaxPool2D(ITorchTensor<float> x, long kernelSize)
115+
{
116+
return new MaxPool2D(kernelSize).Forward(x);
117+
}
118+
119+
static public Module LogSoftMax(long dimension)
120+
{
121+
return new LogSoftMax(dimension);
122+
}
123+
124+
static public ITorchTensor<float> LogSoftMax(ITorchTensor<float> x, long dimension)
125+
{
126+
return new LogSoftMax(dimension).Forward(x);
97127
}
98128

129+
static public Module Dropout(double probability, Func<bool> isTraining)
130+
{
131+
return new Dropout(probability, isTraining);
132+
}
133+
134+
static public ITorchTensor<float> Dropout(ITorchTensor<float> x, double probability, Func<bool> isTraining)
135+
{
136+
return new Dropout(probability, isTraining).Forward(x);
137+
}
138+
}
139+
140+
public abstract partial class Module : IDisposable
141+
{
99142
public abstract ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor);
100143

144+
public virtual void RegisterModule(Module module)
145+
{
146+
}
147+
101148
[DllImport("LibTorchSharp")]
102149
extern static void NN_Module_ZeroGrad(HType module);
103150

@@ -106,6 +153,14 @@ public virtual void ZeroGrad()
106153
NN_Module_ZeroGrad(handle);
107154
}
108155

156+
[DllImport("LibTorchSharp")]
157+
extern static bool NN_IsTraining(HType module);
158+
159+
public virtual bool IsTraining()
160+
{
161+
return NN_IsTraining(handle);
162+
}
163+
109164
[DllImport("LibTorchSharp")]
110165
extern static void NN_GetParameters(HType module, AllocatePinnedArray allocator);
111166

TorchSharp/NN/ReLu.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using System;
2+
using System.Runtime.InteropServices;
3+
using TorchSharp.Tensor;
4+
5+
namespace TorchSharp.NN
6+
{
7+
/// <summary>
8+
/// This class is used to represent a ReLu module.
9+
/// </summary>
10+
public class ReLu : FunctionalModule
11+
{
12+
internal ReLu() : base()
13+
{
14+
}
15+
16+
[DllImport("LibTorchSharp")]
17+
extern static FloatTensor.HType NN_ReluModule_Forward(IntPtr tensor);
18+
19+
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
20+
{
21+
return new FloatTensor(NN_ReluModule_Forward(tensor.Handle));
22+
}
23+
}
24+
}

TorchSharp/NN/Sequential.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@ public Sequential(params Module[] modules) : base(IntPtr.Zero)
1818
{
1919
foreach (var module in modules)
2020
{
21-
Modules.Add(module);
21+
RegisterModule(module);
2222
}
2323
}
2424

25+
public override void RegisterModule(Module module)
26+
{
27+
Modules.Add(module);
28+
}
29+
2530
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
2631
{
2732
if (Modules.Count < 1)

0 commit comments

Comments
 (0)