
Commit 37f73b3

Merge pull request #12 from interesaaat/LibTorchSharpFirstTest
Lib torch sharp first test
2 parents c7a23e5 + 14e90a4

9 files changed, +335 -65 lines

Examples/MNIST.cs

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ private class Model : NN.Module
             private NN.Module fc1 = Linear(320, 50);
             private NN.Module fc2 = Linear(50, 10);
 
-            public Model() : base(IntPtr.Zero)
+            public Model() : base()
             {
                 RegisterModule(conv1);
                 RegisterModule(conv2);
@@ -68,7 +68,7 @@ public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
 
                 using (var l31 = fc1.Forward(x))
                 using (var l32 = Relu(l31))
-                using (var l33 = Dropout(l32, 0.5, _isTraining))
+                using (var l33 = Dropout(l32, 0.5, IsTraining()))
 
                 using (var l41 = fc2.Forward(l33))
 
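Both edits track an API change in NN.Module: custom modules now call the parameterless base constructor instead of handing in IntPtr.Zero, and they ask the module for its training flag via IsTraining() rather than reading a cached field. A minimal sketch of the resulting pattern, with a made-up nested class and layer sizes (mirroring how Model is declared in MNIST.cs):

// Hypothetical custom module following the new pattern.
private class TinyModel : NN.Module
{
    private NN.Module fc = Linear(10, 2);

    public TinyModel() : base()                    // was: base(IntPtr.Zero)
    {
        RegisterModule(fc);
    }

    public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
    {
        using (var x = fc.Forward(tensors[0]))
            return Dropout(x, 0.5, IsTraining());  // was: _isTraining
    }
}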

Test/TorchSharp/TorchSharp.cs

Lines changed: 126 additions & 0 deletions
@@ -1,5 +1,6 @@
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using System;
+using System.Linq;
 using TorchSharp.JIT;
 using TorchSharp.Tensor;
 
@@ -302,6 +303,88 @@ public void TestSetGetBiasInLinear()
             Assert.AreEqual(lin.Bias.NumberOfElements, bias.NumberOfElements);
         }
 
+        [TestMethod]
+        public void TestWeightAndBiasShapeInLinear()
+        {
+            var lin = NN.Module.Linear(1000, 100, true);
+
+            Assert.AreEqual(lin.Weight.Shape.Length, 2);
+            Assert.AreEqual(lin.Weight.Shape[0], 100);
+            Assert.AreEqual(lin.Weight.Shape[1], 1000);
+            Assert.AreEqual(lin.Bias.Shape.Length, 1);
+            Assert.AreEqual(lin.Bias.Shape[0], 100);
+        }
+
+        [TestMethod]
+        public void TestWeightAndBiasParametersInLinear()
+        {
+            var lin = NN.Module.Linear(1000, 100, true);
+            var names = lin.NamedParameters().Select(p => p.name);
+            Assert.IsTrue(names.Contains("weight"));
+            Assert.IsTrue(names.Contains("bias"));
+        }
+
+        [TestMethod]
+        public void TestWeightParameterInLinear()
+        {
+            var lin = NN.Module.Linear(1000, 100, false);
+            var names = lin.NamedParameters().Select(p => p.name);
+            Assert.IsTrue(names.Contains("weight"));
+            Assert.IsFalse(names.Contains("bias"));
+        }
+
+        [TestMethod]
+        public void TestWeightAndBiasShapeInLinear3()
+        {
+            var lin = NN.Module.Linear(1000, 100, true);
+            var weight = lin.GetParameter("weight");
+            var bias = lin.GetParameter("bias");
+            Assert.AreEqual(weight.Shape.Length, 2);
+            Assert.AreEqual(weight.Shape[0], 100);
+            Assert.AreEqual(weight.Shape[1], 1000);
+            Assert.AreEqual(bias.Shape.Length, 1);
+            Assert.AreEqual(bias.Shape[0], 100);
+        }
+
+        [TestMethod]
+        public void TestLinearWithBias()
+        {
+            var lin = NN.Module.Linear(1000, 100, true);
+            var bias = lin.Bias;
+            var weight = lin.Weight.T();
+            var input = FloatTensor.RandomN(new long[] { 1, 1000 });
+            var forward = lin.Forward(input);
+            var matmul = input.MatMul(weight).Add(bias);
+
+            Assert.AreEqual(forward.Shape.Length, matmul.Shape.Length);
+            Assert.AreEqual(forward.Shape[0], matmul.Shape[0]);
+            Assert.AreEqual(forward.Shape[1], matmul.Shape[1]);
+
+            for (int i = 0; i < 100; i++)
+            {
+                Assert.AreEqual(forward.Data[i], matmul.Data[i]);
+            }
+        }
+
+        [TestMethod]
+        public void TestLinearNoBias()
+        {
+            var lin = NN.Module.Linear(1000, 100, false);
+            var weight = lin.Weight.Transpose(0, 1);
+            var input = FloatTensor.RandomN(new long[] { 1, 1000 });
+            var forward = lin.Forward(input);
+            var matmul = input.MatMul(weight);
+
+            Assert.AreEqual(forward.Shape.Length, matmul.Shape.Length);
+            Assert.AreEqual(forward.Shape[0], matmul.Shape[0]);
+            Assert.AreEqual(forward.Shape[1], matmul.Shape[1]);
+
+            for (int i = 0; i < 100; i++)
+            {
+                Assert.AreEqual(forward.Data[i], matmul.Data[i]);
+            }
+        }
+
         [TestMethod]
         public void CreateRelu()
         {
@@ -317,6 +400,7 @@ public void CreateSequence()
             var lin2 = NN.Module.Linear(100, 10);
             var seq = NN.Module.Sequential(lin1, NN.Module.Relu(), lin2);
             var modules = seq.GetModules();
+            Assert.AreEqual(modules.Count(), 3);
         }
 
         [TestMethod]
@@ -501,6 +585,48 @@ public void TestMul()
             }
         }
 
+        [TestMethod]
+        public void TestCustomModule()
+        {
+            var module = new TestModule("test", FloatTensor.RandomN(new long[] { 2, 2 }), true);
+            var name = module.GetName();
+            Assert.IsNotNull(name);
+            Assert.IsTrue(module.HasParameter("test"));
+        }
+
+        [TestMethod]
+        public void TestCustomModuleWithInPlaceModification()
+        {
+            var param = FloatTensor.RandomN(new long[] { 1000, 100 });
+            var module = new TestModule("test", param, true);
+
+            Assert.AreEqual(module.GetParameter("test").Shape[0], 1000);
+            Assert.AreEqual(module.GetParameter("test").Shape[1], 100);
+
+            using (var grad = new AutoGradMode(false))
+            {
+                param.TransposeInPlace(0, 1);
+            }
+            Assert.AreEqual(module.GetParameter("test").Shape[0], 100);
+            Assert.AreEqual(module.GetParameter("test").Shape[1], 1000);
+            Assert.AreEqual(param.Shape[0], 100);
+            Assert.AreEqual(param.Shape[1], 1000);
+        }
+
+        private class TestModule : NN.Module
+        {
+            public TestModule(string name, ITorchTensor<float> tensor, bool withGrad)
+                : base((name, tensor, withGrad))
+            {
+            }
+
+            public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
+            {
+                throw new NotImplementedException();
+            }
+        }
+
+
         /// <summary>
         /// Fully connected Relu net with one hidden layer trained using gradient descent.
         /// Taken from <see cref="https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_nn.html"/>.
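Taken together, TestLinearWithBias and TestLinearNoBias pin down Forward's semantics for Linear: the output must equal input.MatMul(Weight.T()).Add(Bias), or the bias-free product. A standalone sketch of the same check, assuming only the API surface the tests above exercise; the class name and console output are illustrative:

using System;
using System.Linq;
using TorchSharp;
using TorchSharp.Tensor;

class LinearCheck
{
    static void Main()
    {
        // Small sizes keep the example quick; the tests use 1000 -> 100.
        var lin = NN.Module.Linear(4, 3, true);
        var input = FloatTensor.RandomN(new long[] { 1, 4 });

        // Forward should agree with the explicit x * W^T + b formulation.
        var forward = lin.Forward(input);
        var manual = input.MatMul(lin.Weight.T()).Add(lin.Bias);

        var equal = Enumerable.Range(0, 3).All(i => forward.Data[i] == manual.Data[i]);
        Console.WriteLine($"Forward == x * W^T + b: {equal}");

        // NamedParameters() exposes both tensors by name ("weight", "bias").
        foreach (var (name, parameter) in lin.NamedParameters())
            Console.WriteLine($"{name}: shape [{string.Join(", ", parameter.Shape)}]");
    }
}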

TorchSharp/NN/Dropout.cs

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ namespace TorchSharp.NN
     public class Dropout : FunctionalModule<Dropout>
     {
         private double _probability;
+        private bool _isTraining;
 
         internal Dropout(double probability, bool isTraining) : base()
         {
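Only the field declaration is in the hunk; presumably the constructor shown in the context now stores both of its arguments. A hedged reconstruction of that constructor body (the assignments are inferred, not part of this diff):

internal Dropout(double probability, bool isTraining) : base()
{
    _probability = probability;
    _isTraining = isTraining;  // inferred: only the field declaration appears in the hunk
}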

TorchSharp/NN/FunctionalModule.cs

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,11 @@ public override void ZeroGrad()
         {
         }
 
+        public override IEnumerable<(string name, ITorchTensor<float> parameter)> NamedParameters()
+        {
+            return new List<(string, ITorchTensor<float>)>();
+        }
+
         public override IEnumerable<ITorchTensor<float>> Parameters()
         {
             return new List<ITorchTensor<float>>();
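FunctionalModule covers the parameter-free modules such as Relu and Dropout, so the new override mirrors Parameters() and returns an empty enumeration. A quick sanity check in the style of the test suite above (Count() assumes a using for System.Linq):

// Functional modules carry no learnable state, so both enumerations are empty.
var relu = NN.Module.Relu();
Assert.AreEqual(relu.NamedParameters().Count(), 0);
Assert.AreEqual(relu.Parameters().Count(), 0);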

TorchSharp/NN/Linear.cs

Lines changed: 3 additions & 3 deletions
@@ -11,11 +11,11 @@ public Linear(IntPtr handle) : base(handle)
         }
 
         [DllImport("libTorchSharp")]
-        extern static IntPtr THSNN_linearModule(int input, int output, bool hasBias);
+        extern static IntPtr THSNN_linearModule(long input_size, long output_size, bool with_bias);
 
-        public Linear(int input, int output, bool hasBias = false) : base()
+        public Linear(long inputSize, long outputSize, bool hasBias = false) : base()
        {
-            handle = new HType(THSNN_linearModule(input, output, hasBias), true);
+            handle = new HType(THSNN_linearModule(inputSize, outputSize, hasBias), true);
         }
 
         [DllImport("libTorchSharp")]
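Widening the interop signature from int to long lines up with libtorch, which sizes tensors with int64_t, and lets layer dimensions exceed int.MaxValue without overflowing at the interop boundary. A hypothetical call against the new constructor (sizes chosen only to show the 64-bit range, not a sensible allocation):

// Hypothetical: C# long flows through to the native int64_t parameters.
var wide = new Linear(5_000_000_000L, 10L, hasBias: true);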
