@@ -1,5 +1,6 @@
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using System;
+using System.Linq;
 using TorchSharp.JIT;
 using TorchSharp.Tensor;
 
@@ -314,6 +315,37 @@ public void TestWeightAndBiasShapeInLinear()
             Assert.AreEqual(lin.Bias.Shape[0], 100);
         }
 
+        [TestMethod]
+        public void TestWeightAndBiasParametersInLinear()
+        {
+            var lin = NN.Module.Linear(1000, 100, true);
+            var names = lin.NamedParameters().Select(p => p.name);
+            Assert.IsTrue(names.Contains("weight"));
+            Assert.IsTrue(names.Contains("bias"));
+        }
+
+        [TestMethod]
+        public void TestWeightParameterInLinear()
+        {
+            var lin = NN.Module.Linear(1000, 100, false);
+            var names = lin.NamedParameters().Select(p => p.name);
+            Assert.IsTrue(names.Contains("weight"));
+            Assert.IsFalse(names.Contains("bias"));
+        }
+
+        [TestMethod]
+        public void TestWeightAndBiasShapeInLinear3()
+        {
+            var lin = NN.Module.Linear(1000, 100, true);
+            var weight = lin.GetParameter("weight");
+            var bias = lin.GetParameter("bias");
+            Assert.AreEqual(weight.Shape.Length, 2);
+            Assert.AreEqual(weight.Shape[0], 100);
+            Assert.AreEqual(weight.Shape[1], 1000);
+            Assert.AreEqual(bias.Shape.Length, 1);
+            Assert.AreEqual(bias.Shape[0], 100);
+        }
+
         [TestMethod]
         public void TestLinearWithBias()
         {
@@ -368,6 +400,7 @@ public void CreateSequence()
             var lin2 = NN.Module.Linear(100, 10);
             var seq = NN.Module.Sequential(lin1, NN.Module.Relu(), lin2);
             var modules = seq.GetModules();
+            Assert.AreEqual(modules.Count(), 3);
         }
 
         [TestMethod]
@@ -552,6 +585,48 @@ public void TestMul()
             }
         }
 
+        [TestMethod]
+        public void TestCustomModule()
+        {
+            var module = new TestModule("test", FloatTensor.RandomN(new long[] { 2, 2 }), true);
+            var name = module.GetName();
+            Assert.IsNotNull(name);
+            Assert.IsTrue(module.HasParameter("test"));
+        }
+
+        [TestMethod]
+        public void TestCustomModuleWithInPlaceModification()
+        {
+            var param = FloatTensor.RandomN(new long[] { 1000, 100 });
+            var module = new TestModule("test", param, true);
+
+            Assert.AreEqual(module.GetParameter("test").Shape[0], 1000);
+            Assert.AreEqual(module.GetParameter("test").Shape[1], 100);
+
+            using (var grad = new AutoGradMode(false))
+            {
+                param.TransposeInPlace(0, 1);
+            }
+            Assert.AreEqual(module.GetParameter("test").Shape[0], 100);
+            Assert.AreEqual(module.GetParameter("test").Shape[1], 1000);
+            Assert.AreEqual(param.Shape[0], 100);
+            Assert.AreEqual(param.Shape[1], 1000);
+        }
+
+        private class TestModule : NN.Module
+        {
+            public TestModule(string name, ITorchTensor<float> tensor, bool withGrad)
+                : base((name, tensor, withGrad))
+            {
+            }
+
+            public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
+            {
+                throw new NotImplementedException();
+            }
+        }
+
+
         /// <summary>
         /// Fully connected Relu net with one hidden layer trained using gradient descent.
         /// Taken from <see cref="https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_nn.html"/>.