
Commit 14e90a4

Improved new module capabilities plus test cases. Fixes a few bugs.

1 parent 7ed3a7b · commit 14e90a4

6 files changed: +192 -35 lines changed


Examples/MNIST.cs

Lines changed: 2 additions & 2 deletions

@@ -45,7 +45,7 @@ private class Model : NN.Module
         private NN.Module fc1 = Linear(320, 50);
         private NN.Module fc2 = Linear(50, 10);

-        public Model() : base(IntPtr.Zero)
+        public Model() : base()
         {
             RegisterModule(conv1);
             RegisterModule(conv2);
@@ -68,7 +68,7 @@ public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)

             using (var l31 = fc1.Forward(x))
             using (var l32 = Relu(l31))
-            using (var l33 = Dropout(l32, 0.5, _isTraining))
+            using (var l33 = Dropout(l32, 0.5, IsTraining()))

             using (var l41 = fc2.Forward(l33))
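Taken together, the two hunks above change the custom-module pattern: a subclass no longer passes IntPtr.Zero to the base constructor, and the training flag is obtained from the module's own IsTraining() query rather than a cached field. A minimal sketch of the resulting pattern (layer sizes are hypothetical; Linear, Relu, and Dropout are assumed to be in scope exactly as in the Model class above):

    private class TinyModel : NN.Module
    {
        private NN.Module fc = Linear(10, 2);

        public TinyModel() : base()   // parameterless base call replaces base(IntPtr.Zero)
        {
            RegisterModule(fc);
        }

        public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
        {
            using (var h = fc.Forward(tensors[0]))
            using (var a = Relu(h))
                // The module is asked for its mode on every call instead of
                // caching a bool at construction time.
                return Dropout(a, 0.5, IsTraining());
        }
    }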

Test/TorchSharp/TorchSharp.cs

Lines changed: 75 additions & 0 deletions

@@ -1,5 +1,6 @@
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using System;
+using System.Linq;
 using TorchSharp.JIT;
 using TorchSharp.Tensor;

@@ -314,6 +315,37 @@ public void TestWeightAndBiasShapeInLinear()
             Assert.AreEqual(lin.Bias.Shape[0], 100);
         }

+        [TestMethod]
+        public void TestWeightAndBiasParametersInLinear()
+        {
+            var lin = NN.Module.Linear(1000, 100, true);
+            var names = lin.NamedParameters().Select(p => p.name);
+            Assert.IsTrue(names.Contains("weight"));
+            Assert.IsTrue(names.Contains("bias"));
+        }
+
+        [TestMethod]
+        public void TestWeightParameterInLinear()
+        {
+            var lin = NN.Module.Linear(1000, 100, false);
+            var names = lin.NamedParameters().Select(p => p.name);
+            Assert.IsTrue(names.Contains("weight"));
+            Assert.IsFalse(names.Contains("bias"));
+        }
+
+        [TestMethod]
+        public void TestWeightAndBiasShapeInLinear3()
+        {
+            var lin = NN.Module.Linear(1000, 100, true);
+            var weight = lin.GetParameter("weight");
+            var bias = lin.GetParameter("bias");
+            Assert.AreEqual(weight.Shape.Length, 2);
+            Assert.AreEqual(weight.Shape[0], 100);
+            Assert.AreEqual(weight.Shape[1], 1000);
+            Assert.AreEqual(bias.Shape.Length, 1);
+            Assert.AreEqual(bias.Shape[0], 100);
+        }
+
         [TestMethod]
         public void TestLinearWithBias()
         {
@@ -368,6 +400,7 @@ public void CreateSequence()
             var lin2 = NN.Module.Linear(100, 10);
             var seq = NN.Module.Sequential(lin1, NN.Module.Relu(), lin2);
             var modules = seq.GetModules();
+            Assert.AreEqual(modules.Count(), 3);
         }

         [TestMethod]
@@ -552,6 +585,48 @@ public void TestMul()
             }
         }

+        [TestMethod]
+        public void TestCustomModule()
+        {
+            var module = new TestModule("test", FloatTensor.RandomN(new long[] { 2, 2 }), true);
+            var name = module.GetName();
+            Assert.IsNotNull(name);
+            Assert.IsTrue(module.HasParameter("test"));
+        }
+
+        [TestMethod]
+        public void TestCustomModuleWithInPlaceModification()
+        {
+            var param = FloatTensor.RandomN(new long[] { 1000, 100 });
+            var module = new TestModule("test", param, true);
+
+            Assert.AreEqual(module.GetParameter("test").Shape[0], 1000);
+            Assert.AreEqual(module.GetParameter("test").Shape[1], 100);
+
+            using (var grad = new AutoGradMode(false))
+            {
+                param.TransposeInPlace(0, 1);
+            }
+            Assert.AreEqual(module.GetParameter("test").Shape[0], 100);
+            Assert.AreEqual(module.GetParameter("test").Shape[1], 1000);
+            Assert.AreEqual(param.Shape[0], 100);
+            Assert.AreEqual(param.Shape[1], 1000);
+        }
+
+        private class TestModule : NN.Module
+        {
+            public TestModule(string name, ITorchTensor<float> tensor, bool withGrad)
+                : base((name, tensor, withGrad))
+            {
+            }
+
+            public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
+            {
+                throw new NotImplementedException();
+            }
+        }
+
+
         /// <summary>
         /// Fully connected Relu net with one hidden layer trained using gradient descent.
         /// Taken from <see cref="https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_nn.html"/>.
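Beyond exercising the new (name, tensor, withGrad) base constructor, TestCustomModuleWithInPlaceModification pins down an aliasing property: the tensor handed to the constructor is registered by reference, not copied, so mutating it in place is visible through GetParameter. A condensed sketch of that behavior, reusing the TestModule class from the diff (shapes are hypothetical):

    var weight = FloatTensor.RandomN(new long[] { 4, 3 });
    var module = new TestModule("proj", weight, true);

    // Transpose the caller's tensor with gradient tracking disabled; the
    // registered parameter shares its storage, so both now report 3 x 4.
    using (new AutoGradMode(false))
    {
        weight.TransposeInPlace(0, 1);
    }
    Assert.AreEqual(module.GetParameter("proj").Shape[0], 3);
    Assert.AreEqual(module.GetParameter("proj").Shape[1], 4);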

TorchSharp/NN/Dropout.cs

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ namespace TorchSharp.NN
     public class Dropout : FunctionalModule<Dropout>
     {
         private double _probability;
+        private bool _isTraining;

         internal Dropout(double probability, bool isTraining) : base()
         {

TorchSharp/NN/FunctionalModule.cs

Lines changed: 5 additions & 0 deletions

@@ -21,6 +21,11 @@ public override void ZeroGrad()
         {
         }

+        public override IEnumerable<(string name, ITorchTensor<float> parameter)> NamedParameters()
+        {
+            return new List<(string, ITorchTensor<float>)>();
+        }
+
         public override IEnumerable<ITorchTensor<float>> Parameters()
         {
             return new List<ITorchTensor<float>>();
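With this override in place, NamedParameters() can be called uniformly across a module tree: functional modules such as Relu or Dropout report an empty collection, while parameterful modules enumerate their tensors by name. A small sketch of the expected behavior (inferred from the tests above; assumes using System.Linq is in scope):

    var lin = NN.Module.Linear(8, 4, true);
    foreach (var (name, parameter) in lin.NamedParameters())
    {
        // Prints "weight" and "bias" along with their shapes.
        Console.WriteLine($"{name}: [{string.Join(", ", parameter.Shape)}]");
    }

    // A functional module owns no tensors, so the sequence is empty.
    Assert.IsFalse(NN.Module.Relu().NamedParameters().Any());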

TorchSharp/NN/Linear.cs

Lines changed: 3 additions & 3 deletions

@@ -11,11 +11,11 @@ public Linear(IntPtr handle) : base(handle)
         }

         [DllImport("libTorchSharp")]
-        extern static IntPtr THSNN_linearModule(int input, int output, bool hasBias);
+        extern static IntPtr THSNN_linearModule(long input_size, long output_size, bool with_bias);

-        public Linear(int input, int output, bool hasBias = false) : base()
+        public Linear(long inputSize, long outputSize, bool hasBias = false) : base()
         {
-            handle = new HType(THSNN_linearModule(input, output, hasBias), true);
+            handle = new HType(THSNN_linearModule(inputSize, outputSize, hasBias), true);
         }

         [DllImport("libTorchSharp")]
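The only change here is widening the size arguments from int to long on both sides of the P/Invoke boundary, so the managed signature matches the native one and large layer sizes no longer have to squeeze through a 32-bit int. Call sites are unchanged apart from the argument types (a sketch; sizes are hypothetical):

    long inputSize = 4096, outputSize = 1024;   // now long rather than int
    var lin = new Linear(inputSize, outputSize, hasBias: true);

    // The NN.Module.Linear factory used throughout the tests takes the
    // same arguments.
    var viaFactory = NN.Module.Linear(inputSize, outputSize, true);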
