@@ -7,12 +7,12 @@ namespace ParallelReverseAutoDiff.Test.GraphAttentionPaths
77{
88 using System ;
99 using System . IO ;
10- using ParallelReverseAutoDiff . GnnExample . Common ;
1110 using ParallelReverseAutoDiff . RMAD ;
1211 using ParallelReverseAutoDiff . Test . GraphAttentionPaths . AttentionMessagePassing ;
1312 using ParallelReverseAutoDiff . Test . GraphAttentionPaths . EdgeAttention ;
1413 using ParallelReverseAutoDiff . Test . GraphAttentionPaths . Embedding ;
1514 using ParallelReverseAutoDiff . Test . GraphAttentionPaths . GCN ;
15+ using ParallelReverseAutoDiff . Test . GraphAttentionPaths . Transformer ;
1616
1717 /// <summary>
1818 /// Graph Attention Paths Neural Network.
@@ -22,7 +22,7 @@ public class GraphAttentionPathsNeuralNetwork
2222 private const string WEIGHTSSAVEPATH = "D:\\ models\\ initialWeights2.json" ;
2323 private readonly List < EmbeddingNeuralNetwork > embeddingNeuralNetwork ;
2424 private readonly List < EdgeAttentionNeuralNetwork > edgeAttentionNeuralNetwork ;
25- private readonly List < LstmNeuralNetwork > lstmNeuralNetwork ;
25+ private readonly List < TransformerNeuralNetwork > transformerNeuralNetwork ;
2626 private readonly List < AttentionMessagePassingNeuralNetwork > attentionMessagePassingNeuralNetwork ;
2727 private readonly GcnNeuralNetwork gcnNeuralNetwork ;
2828 private readonly ReadoutNeuralNetwork readoutNeuralNetwork ;
@@ -37,7 +37,7 @@ public class GraphAttentionPathsNeuralNetwork
3737 private readonly double learningRate ;
3838 private readonly double clipValue ;
3939 private readonly Dictionary < int , Guid > typeToIdMap ;
40- private readonly Dictionary < int , Guid > typeToIdMapLstm ;
40+ private readonly Dictionary < int , Guid > typeToIdMapTransformer ;
4141 private readonly Dictionary < int , Guid > typeToIdMapAttention ;
4242 private readonly Dictionary < int , Guid > typeToIdMapEmbeddings ;
4343 private readonly Dictionary < GapPath , List < GapPath > > connectedPathsMap ;
@@ -72,11 +72,11 @@ public GraphAttentionPathsNeuralNetwork(List<GapGraph> graphs, int batchSize, in
7272 this . embeddingNeuralNetwork = new List < EmbeddingNeuralNetwork > ( ) ;
7373 this . edgeAttentionNeuralNetwork = new List < EdgeAttentionNeuralNetwork > ( ) ;
7474 this . typeToIdMap = new Dictionary < int , Guid > ( ) ;
75- this . typeToIdMapLstm = new Dictionary < int , Guid > ( ) ;
75+ this . typeToIdMapTransformer = new Dictionary < int , Guid > ( ) ;
7676 this . typeToIdMapAttention = new Dictionary < int , Guid > ( ) ;
7777 this . typeToIdMapEmbeddings = new Dictionary < int , Guid > ( ) ;
7878 this . connectedPathsMap = new Dictionary < GapPath , List < GapPath > > ( ) ;
79- this . lstmNeuralNetwork = new List < LstmNeuralNetwork > ( ) ;
79+ this . transformerNeuralNetwork = new List < TransformerNeuralNetwork > ( ) ;
8080 this . attentionMessagePassingNeuralNetwork = new List < AttentionMessagePassingNeuralNetwork > ( ) ;
8181 this . gcnNeuralNetwork = new GcnNeuralNetwork ( numLayers , 4 , this . numFeatures , learningRate , clipValue ) ;
8282 this . readoutNeuralNetwork = new ReadoutNeuralNetwork ( numLayers , numQueries , 4 , this . numFeatures , learningRate , clipValue ) ;
@@ -106,10 +106,10 @@ public async Task Initialize()
106106
107107 for ( int i = 0 ; i < 7 ; ++ i )
108108 {
109- var model = new LstmNeuralNetwork ( this . numFeatures * ( int ) Math . Pow ( 2d , ( double ) this . numLayers ) , 500 , this . numFeatures * ( int ) Math . Pow ( 2d , ( double ) this . numLayers ) * 2 , i + 2 , this . numLayers , this . learningRate , this . clipValue ) ;
110- this . lstmNeuralNetwork . Add ( model ) ;
111- await this . lstmNeuralNetwork [ i ] . Initialize ( ) ;
112- this . modelLayers = this . modelLayers . Concat ( this . lstmNeuralNetwork [ i ] . ModelLayers ) . ToList ( ) ;
109+ var model = new TransformerNeuralNetwork ( this . numLayers , this . numQueries / 2 , i + 2 , this . numFeatures * ( int ) Math . Pow ( 2d , ( double ) this . numLayers ) , i + 2 , this . learningRate , this . clipValue ) ;
110+ this . transformerNeuralNetwork . Add ( model ) ;
111+ await this . transformerNeuralNetwork [ i ] . Initialize ( ) ;
112+ this . modelLayers = this . modelLayers . Concat ( this . transformerNeuralNetwork [ i ] . ModelLayers ) . ToList ( ) ;
113113 }
114114
115115 for ( int i = 0 ; i < 7 ; ++ i )
@@ -135,27 +135,31 @@ public async Task Initialize()
135135 /// </summary>
public void SaveWeights()
{
    // Every call writes into a new folder named after a freshly generated GUID.
    var directory = $"E:\\store\\{Guid.NewGuid()}";
    Directory.CreateDirectory(directory);

    // One JSON file per model layer, named after the layer's position in modelLayers.
    for (int layerIndex = 0; layerIndex < this.modelLayers.Count; ++layerIndex)
    {
        var target = new FileInfo($"{directory}\\layer{layerIndex}.json");
        this.modelLayers[layerIndex].SaveWeightsAndMoments(target);
    }
}
142148
/// <summary>
/// Loads the weights and moments of every model layer from a stored snapshot folder.
/// </summary>
/// <remarks>
/// The snapshot folder is fixed in code; it is not the folder written by the most recent
/// call to SaveWeights, which generates a new folder name on every call.
/// All layer files are checked up front so that a missing file is reported before any
/// layer has been modified, instead of leaving the model with only some layers loaded.
/// </remarks>
/// <exception cref="FileNotFoundException">A layer file is missing from the snapshot folder.</exception>
public void ApplyWeights()
{
    var guid = "b93e5314-3a40-4353-8b44-0795cbfe0d4e";
    var dir = $"E:\\store\\{guid}";

    // Resolve and verify every layer file before touching any layer.
    var files = new FileInfo[this.modelLayers.Count];
    for (int i = 0; i < files.Length; ++i)
    {
        var file = new FileInfo($"{dir}\\layer{i}.json");
        if (!file.Exists)
        {
            throw new FileNotFoundException($"Missing weights file for model layer {i}.", file.FullName);
        }

        files[i] = file;
    }

    for (int i = 0; i < files.Length; ++i)
    {
        this.modelLayers[i].LoadWeightsAndMoments(files[i]);

        // A forced, blocking full collection runs after each layer is loaded.
        // NOTE(review): presumably this keeps peak memory down while loading large layers - confirm.
        GC.Collect(GC.MaxGeneration, GCCollectionMode.Forced, true);
    }
}
160164
161165 /// <summary>
@@ -279,26 +283,26 @@ public DeepMatrix Forward()
279283 }
280284 }
281285
282- Dictionary < int , List < DeepMatrix > > inputsByLength = new Dictionary < int , List < DeepMatrix > > ( ) ;
286+ Dictionary < int , List < Matrix > > inputsByLength = new Dictionary < int , List < Matrix > > ( ) ;
283287 Dictionary < ( int Length , int Index ) , GapPath > pathIndexMap = new Dictionary < ( int Length , int Index ) , GapPath > ( ) ;
284288
285289 foreach ( var graph in this . gapGraphs )
286290 {
287291 foreach ( var path in graph . GapPaths )
288292 {
289293 var pathLength = path . Nodes . Count ;
290- var input = new DeepMatrix ( pathLength , this . numFeatures * ( int ) Math . Pow ( 2d , ( double ) this . numLayers ) , 1 ) ;
291- for ( int i = 0 ; i < input . Depth ; ++ i )
294+ var input = new Matrix ( pathLength , this . numFeatures * ( int ) Math . Pow ( 2d , ( double ) this . numLayers ) ) ;
295+ for ( int i = 0 ; i < input . Rows ; ++ i )
292296 {
293- for ( int j = 0 ; j < input . Rows ; ++ j )
297+ for ( int j = 0 ; j < input . Cols ; ++ j )
294298 {
295- input [ i ] [ j ] [ 0 ] = path . Nodes [ i ] . FeatureVector [ j ] [ 0 ] ;
299+ input [ i ] [ j ] = path . Nodes [ i ] . FeatureVector [ j ] [ 0 ] ;
296300 }
297301 }
298302
299303 if ( ! inputsByLength . ContainsKey ( pathLength ) )
300304 {
301- inputsByLength [ pathLength ] = new List < DeepMatrix > ( ) ;
305+ inputsByLength [ pathLength ] = new List < Matrix > ( ) ;
302306 }
303307
304308 inputsByLength [ pathLength ] . Add ( input ) ;
@@ -310,15 +314,14 @@ public DeepMatrix Forward()
310314 foreach ( var length in inputsByLength . Keys )
311315 {
312316 var batchedInput = inputsByLength [ length ] . ToArray ( ) ; // Array of Matrix, one per path of this length; it is wrapped into a single DeepMatrix for the forward pass below
313- var switched = CommonMatrixUtils . SwitchFirstTwoDimensions ( batchedInput ) ;
314- var lstmNet = this . lstmNeuralNetwork [ length - 2 ] ; // Because a path must have a length of at least two
315- lstmNet . Parameters . BatchSize = batchedInput . Length ;
316- lstmNet . InitializeState ( ) ;
317- lstmNet . AutomaticForwardPropagate ( new FourDimensionalMatrix ( switched ) ) ;
317+ var transformerNet = this . transformerNeuralNetwork [ length - 2 ] ; // Because a path must have a length of at least two
318+ transformerNet . Parameters . BatchSize = batchedInput . Length ;
319+ transformerNet . InitializeState ( ) ;
320+ transformerNet . AutomaticForwardPropagate ( new DeepMatrix ( batchedInput ) ) ;
318321 var id = Guid . NewGuid ( ) ;
319- this . typeToIdMapLstm . Add ( length , id ) ;
320- lstmNet . StoreOperationIntermediates ( id ) ;
321- var output = lstmNet . OutputPathFeatures [ length - 1 ] ;
322+ this . typeToIdMapTransformer . Add ( length , id ) ;
323+ transformerNet . StoreOperationIntermediates ( id ) ;
324+ var output = transformerNet . Output ;
322325 for ( int i = 0 ; i < output . Depth ; ++ i )
323326 {
324327 var path = pathIndexMap [ ( length , i ) ] ;
@@ -522,7 +525,7 @@ public async Task Backward(DeepMatrix gradientOfLossWrtReadoutOutput)
522525 this . ApplyGradients ( pathToGradientsMap ) ;
523526
524527 Dictionary < int , List < Matrix > > pathLengthToGradientMap = new Dictionary < int , List < Matrix > > ( ) ;
525- Dictionary < ( int , int ) , GapPath > indexesToPathMapLstm = new Dictionary < ( int , int ) , GapPath > ( ) ;
528+ Dictionary < ( int , int ) , GapPath > indexesToPathMapTransformer = new Dictionary < ( int , int ) , GapPath > ( ) ;
526529 foreach ( var graph in this . gapGraphs )
527530 {
528531 foreach ( var path in graph . GapPaths )
@@ -531,12 +534,12 @@ public async Task Backward(DeepMatrix gradientOfLossWrtReadoutOutput)
531534 var gradient = pathToGradientsMap [ path ] . Item1 ;
532535 if ( pathLengthToGradientMap . ContainsKey ( pathLength ) )
533536 {
534- indexesToPathMapLstm . Add ( ( pathLength , pathLengthToGradientMap [ pathLength ] . Count ) , path ) ;
537+ indexesToPathMapTransformer . Add ( ( pathLength , pathLengthToGradientMap [ pathLength ] . Count ) , path ) ;
535538 pathLengthToGradientMap [ pathLength ] . Add ( gradient ) ;
536539 }
537540 else
538541 {
539- indexesToPathMapLstm . Add ( ( pathLength , 0 ) , path ) ;
542+ indexesToPathMapTransformer . Add ( ( pathLength , 0 ) , path ) ;
540543 pathLengthToGradientMap . Add ( pathLength , new List < Matrix > { gradient } ) ;
541544 }
542545 }
@@ -545,24 +548,24 @@ public async Task Backward(DeepMatrix gradientOfLossWrtReadoutOutput)
545548 Dictionary < GapNode , Matrix > nodeToGradientMap = new Dictionary < GapNode , Matrix > ( ) ;
546549 foreach ( var key in pathLengthToGradientMap . Keys )
547550 {
548- var lstmNet = this . lstmNeuralNetwork [ key - 2 ] ;
549- lstmNet . RestoreOperationIntermediates ( this . typeToIdMapLstm [ key ] ) ;
550- var lstmGradient = CommonMatrixUtils . SwitchFirstTwoDimensions ( ( await lstmNet . AutomaticBackwardPropagate ( new DeepMatrix ( pathLengthToGradientMap [ key ] . ToArray ( ) ) ) ) . ToArray ( ) ) ;
551+ var transformerNet = this . transformerNeuralNetwork [ key - 2 ] ;
552+ transformerNet . RestoreOperationIntermediates ( this . typeToIdMapTransformer [ key ] ) ;
553+ var transformerGradient = await transformerNet . AutomaticBackwardPropagate ( new DeepMatrix ( pathLengthToGradientMap [ key ] . ToArray ( ) ) ) ;
551554
552- for ( int i = 0 ; i < lstmGradient . Length ; ++ i )
555+ for ( int i = 0 ; i < transformerGradient . Depth ; ++ i )
553556 {
554- var path = indexesToPathMapLstm [ ( key , i ) ] ;
557+ var path = indexesToPathMapTransformer [ ( key , i ) ] ;
555558 var nodeCount = path . Nodes . Count ;
556559 for ( int j = 0 ; j < nodeCount ; ++ j )
557560 {
558561 var node = path . Nodes [ j ] ;
559562 if ( ! nodeToGradientMap . ContainsKey ( node ) )
560563 {
561- nodeToGradientMap . Add ( node , lstmGradient [ i ] [ j ] ) ;
564+ nodeToGradientMap . Add ( node , new Matrix ( transformerGradient [ i ] [ j ] ) . Transpose ( ) ) ;
562565 }
563566 else
564567 {
565- nodeToGradientMap [ node ] . Accumulate ( lstmGradient [ i ] [ j ] . ToArray ( ) ) ;
568+ nodeToGradientMap [ node ] . Accumulate ( new Matrix ( transformerGradient [ i ] [ j ] ) . Transpose ( ) . ToArray ( ) ) ;
566569 }
567570 }
568571 }
0 commit comments