@@ -60,6 +60,7 @@ public TransformerNeuralNetwork(int numLayers, int numQueries, int numPaths, int
6060 . AddModelElementGroup ( "VB" , new [ ] { 1 , numInputOutputFeatures } , InitializationType . Zeroes ) ;
6161 var inputLayer = inputLayerBuilder . Build ( ) ;
6262 this . inputLayers . Add ( inputLayer ) ;
63+ numInputFeatures = numInputOutputFeatures ;
6364 }
6465
6566 this . nestedLayers = new List < IModelLayer > ( ) ;
@@ -73,6 +74,7 @@ public TransformerNeuralNetwork(int numLayers, int numQueries, int numPaths, int
7374 . AddModelElementGroup ( "QB" , new [ ] { numQueries , 1 , numNestedOutputFeatures } , InitializationType . Zeroes ) ;
7475 var nestedLayer = nestedLayerBuilder . Build ( ) ;
7576 this . nestedLayers . Add ( nestedLayer ) ;
77+ numNestedFeatures = numNestedOutputFeatures ;
7678 outputFeaturesList . Add ( numNestedOutputFeatures * numQueries ) ;
7779 }
7880
@@ -85,18 +87,11 @@ public TransformerNeuralNetwork(int numLayers, int numQueries, int numPaths, int
8587 . AddModelElementGroup ( "F2W" , new [ ] { outputFeaturesList [ i ] , ( outputFeaturesList [ i ] / 2 ) } , InitializationType . Xavier )
8688 . AddModelElementGroup ( "F2B" , new [ ] { 1 , ( outputFeaturesList [ i ] / 2 ) } , InitializationType . Xavier )
8789 . AddModelElementGroup ( "Beta" , new [ ] { 1 , 1 } , InitializationType . He ) ;
88- if ( i < ( this . NumLayers - 1 ) )
89- {
90- outputLayerBuilder
91- . AddModelElementGroup ( "R" , new [ ] { ( outputFeaturesList [ i ] / 2 ) , this . NumFeatures } , InitializationType . Xavier )
92- . AddModelElementGroup ( "RB" , new [ ] { 1 , this . NumFeatures } , InitializationType . Zeroes ) ;
93- }
94- else
95- {
90+
9691 outputLayerBuilder
9792 . AddModelElementGroup ( "R" , new [ ] { ( outputFeaturesList [ i ] / 2 ) , this . NumFeatures * 2 } , InitializationType . Xavier )
9893 . AddModelElementGroup ( "RB" , new [ ] { 1 , this . NumFeatures * 2 } , InitializationType . Zeroes ) ;
99- }
94+
10095 var outputLayer = outputLayerBuilder . Build ( ) ;
10196 this . outputLayers . Add ( outputLayer ) ;
10297 }
0 commit comments