using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
+using System.Linq;
using TorchSharp.JIT;
using TorchSharp.Tensor;

@@ -302,6 +303,88 @@ public void TestSetGetBiasInLinear()
            Assert.AreEqual(lin.Bias.NumberOfElements, bias.NumberOfElements);
        }

+        [TestMethod]
+        public void TestWeightAndBiasShapeInLinear()
+        {
+            var lin = NN.Module.Linear(1000, 100, true);
+
+            Assert.AreEqual(lin.Weight.Shape.Length, 2);
+            Assert.AreEqual(lin.Weight.Shape[0], 100);
+            Assert.AreEqual(lin.Weight.Shape[1], 1000);
+            Assert.AreEqual(lin.Bias.Shape.Length, 1);
+            Assert.AreEqual(lin.Bias.Shape[0], 100);
+        }
+
+        [TestMethod]
+        public void TestWeightAndBiasParametersInLinear()
+        {
+            var lin = NN.Module.Linear(1000, 100, true);
+            var names = lin.NamedParameters().Select(p => p.name);
+            Assert.IsTrue(names.Contains("weight"));
+            Assert.IsTrue(names.Contains("bias"));
+        }
+
+        [TestMethod]
+        public void TestWeightParameterInLinear()
+        {
+            var lin = NN.Module.Linear(1000, 100, false);
+            var names = lin.NamedParameters().Select(p => p.name);
+            Assert.IsTrue(names.Contains("weight"));
+            Assert.IsFalse(names.Contains("bias"));
+        }
+
+        [TestMethod]
+        public void TestWeightAndBiasShapeInLinear3()
+        {
+            var lin = NN.Module.Linear(1000, 100, true);
+            var weight = lin.GetParameter("weight");
+            var bias = lin.GetParameter("bias");
+            Assert.AreEqual(weight.Shape.Length, 2);
+            Assert.AreEqual(weight.Shape[0], 100);
+            Assert.AreEqual(weight.Shape[1], 1000);
+            Assert.AreEqual(bias.Shape.Length, 1);
+            Assert.AreEqual(bias.Shape[0], 100);
+        }
+
+        [TestMethod]
+        public void TestLinearWithBias()
+        {
+            var lin = NN.Module.Linear(1000, 100, true);
+            var bias = lin.Bias;
+            var weight = lin.Weight.T();
+            var input = FloatTensor.RandomN(new long[] { 1, 1000 });
+            var forward = lin.Forward(input);
+            var matmul = input.MatMul(weight).Add(bias);
+
+            Assert.AreEqual(forward.Shape.Length, matmul.Shape.Length);
+            Assert.AreEqual(forward.Shape[0], matmul.Shape[0]);
+            Assert.AreEqual(forward.Shape[1], matmul.Shape[1]);
+
+            for (int i = 0; i < 100; i++)
+            {
+                Assert.AreEqual(forward.Data[i], matmul.Data[i]);
+            }
+        }
+
+        [TestMethod]
+        public void TestLinearNoBias()
+        {
+            var lin = NN.Module.Linear(1000, 100, false);
+            var weight = lin.Weight.Transpose(0, 1);
+            var input = FloatTensor.RandomN(new long[] { 1, 1000 });
+            var forward = lin.Forward(input);
+            var matmul = input.MatMul(weight);
+
+            Assert.AreEqual(forward.Shape.Length, matmul.Shape.Length);
+            Assert.AreEqual(forward.Shape[0], matmul.Shape[0]);
+            Assert.AreEqual(forward.Shape[1], matmul.Shape[1]);
+
+            for (int i = 0; i < 100; i++)
+            {
+                Assert.AreEqual(forward.Data[i], matmul.Data[i]);
+            }
+        }
+
        [TestMethod]
        public void CreateRelu()
        {
@@ -317,6 +400,7 @@ public void CreateSequence()
            var lin2 = NN.Module.Linear(100, 10);
            var seq = NN.Module.Sequential(lin1, NN.Module.Relu(), lin2);
            var modules = seq.GetModules();
+            Assert.AreEqual(modules.Count(), 3);
        }

        [TestMethod]
@@ -501,6 +585,48 @@ public void TestMul()
            }
        }

+        [TestMethod]
+        public void TestCustomModule()
+        {
+            var module = new TestModule("test", FloatTensor.RandomN(new long[] { 2, 2 }), true);
+            var name = module.GetName();
+            Assert.IsNotNull(name);
+            Assert.IsTrue(module.HasParameter("test"));
+        }
+
+        [TestMethod]
+        public void TestCustomModuleWithInPlaceModification()
+        {
+            var param = FloatTensor.RandomN(new long[] { 1000, 100 });
+            var module = new TestModule("test", param, true);
+
+            Assert.AreEqual(module.GetParameter("test").Shape[0], 1000);
+            Assert.AreEqual(module.GetParameter("test").Shape[1], 100);
+
+            using (var grad = new AutoGradMode(false))
+            {
+                param.TransposeInPlace(0, 1);
+            }
+            Assert.AreEqual(module.GetParameter("test").Shape[0], 100);
+            Assert.AreEqual(module.GetParameter("test").Shape[1], 1000);
+            Assert.AreEqual(param.Shape[0], 100);
+            Assert.AreEqual(param.Shape[1], 1000);
+        }
+
+        private class TestModule : NN.Module
+        {
+            public TestModule(string name, ITorchTensor<float> tensor, bool withGrad)
+                : base((name, tensor, withGrad))
+            {
+            }
+
+            public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
+            {
+                throw new NotImplementedException();
+            }
+        }
+
+
        /// <summary>
        /// Fully connected Relu net with one hidden layer trained using gradient descent.
        /// Taken from <see cref="https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_nn.html"/>.