@@ -33,7 +33,6 @@ public class GraphAttentionPathsNeuralNetwork
     private readonly int numQueries;
     private readonly int alphabetSize;
     private readonly int embeddingSize;
-    private readonly int batchSize;
     private readonly double learningRate;
     private readonly double clipValue;
     private readonly Dictionary<int, Guid> typeToIdMap;
@@ -48,15 +47,14 @@ public class GraphAttentionPathsNeuralNetwork
     /// Initializes a new instance of the <see cref="GraphAttentionPathsNeuralNetwork"/> class.
     /// </summary>
     /// <param name="graphs">The graphs.</param>
-    /// <param name="batchSize">The batch size.</param>
     /// <param name="numIndices">The number of indices.</param>
     /// <param name="alphabetSize">The alphabet size.</param>
     /// <param name="embeddingSize">The embedding size.</param>
     /// <param name="numLayers">The number of layers.</param>
     /// <param name="numQueries">The number of queries.</param>
     /// <param name="learningRate">The learning rate.</param>
     /// <param name="clipValue">The clip value.</param>
-    public GraphAttentionPathsNeuralNetwork(List<GapGraph> graphs, int batchSize, int numIndices, int alphabetSize, int embeddingSize, int numLayers, int numQueries, double learningRate, double clipValue)
+    public GraphAttentionPathsNeuralNetwork(List<GapGraph> graphs, int numIndices, int alphabetSize, int embeddingSize, int numLayers, int numQueries, double learningRate, double clipValue)
     {
         this.gapGraphs = graphs;
         this.numFeatures = (numIndices * embeddingSize) + 3;
@@ -65,21 +63,75 @@ public GraphAttentionPathsNeuralNetwork(List<GapGraph> graphs, int batchSize, in
         this.numIndices = numIndices;
         this.numLayers = numLayers;
         this.numQueries = numQueries;
-        this.batchSize = batchSize;
         this.learningRate = learningRate;
         this.clipValue = clipValue;
         this.modelLayers = new List<IModelLayer>();
         this.embeddingNeuralNetwork = new List<EmbeddingNeuralNetwork>();
         this.edgeAttentionNeuralNetwork = new List<EdgeAttentionNeuralNetwork>();
+        this.transformerNeuralNetwork = new List<TransformerNeuralNetwork>();
+        this.attentionMessagePassingNeuralNetwork = new List<AttentionMessagePassingNeuralNetwork>();
+        this.gcnNeuralNetwork = new GcnNeuralNetwork(numLayers, 4, this.numFeatures, learningRate, clipValue);
+        this.readoutNeuralNetwork = new ReadoutNeuralNetwork(numLayers, numQueries, 4, this.numFeatures, learningRate, clipValue);
         this.typeToIdMap = new Dictionary<int, Guid>();
         this.typeToIdMapTransformer = new Dictionary<int, Guid>();
         this.typeToIdMapAttention = new Dictionary<int, Guid>();
         this.typeToIdMapEmbeddings = new Dictionary<int, Guid>();
         this.connectedPathsMap = new Dictionary<GapPath, List<GapPath>>();
-        this.transformerNeuralNetwork = new List<TransformerNeuralNetwork>();
-        this.attentionMessagePassingNeuralNetwork = new List<AttentionMessagePassingNeuralNetwork>();
-        this.gcnNeuralNetwork = new GcnNeuralNetwork(numLayers, 4, this.numFeatures, learningRate, clipValue);
-        this.readoutNeuralNetwork = new ReadoutNeuralNetwork(numLayers, numQueries, 4, this.numFeatures, learningRate, clipValue);
+    }
+
+    /// <summary>
+    /// Reset the network.
+    /// </summary>
+    /// <returns>A task.</returns>
+    public async Task Reset()
+    {
+        this.typeToIdMap.Clear();
+        this.typeToIdMapTransformer.Clear();
+        this.typeToIdMapAttention.Clear();
+        this.typeToIdMapEmbeddings.Clear();
+        this.connectedPathsMap.Clear();
+
+        for (int i = 0; i < 7; ++i)
+        {
+            await this.embeddingNeuralNetwork[i].Initialize();
+            this.embeddingNeuralNetwork[i].Parameters.AdamIteration++;
+        }
+
+        for (int i = 0; i < 7; ++i)
+        {
+            await this.edgeAttentionNeuralNetwork[i].Initialize();
+            this.edgeAttentionNeuralNetwork[i].Parameters.AdamIteration++;
+        }
+
+        for (int i = 0; i < 7; ++i)
+        {
+            await this.transformerNeuralNetwork[i].Initialize();
+            this.transformerNeuralNetwork[i].Parameters.AdamIteration++;
+        }
+
+        for (int i = 0; i < 7; ++i)
+        {
+            await this.attentionMessagePassingNeuralNetwork[i].Initialize();
+            this.attentionMessagePassingNeuralNetwork[i].Parameters.AdamIteration++;
+        }
+
+        await this.gcnNeuralNetwork.Initialize();
+        this.gcnNeuralNetwork.Parameters.AdamIteration++;
+
+        await this.readoutNeuralNetwork.Initialize();
+        this.readoutNeuralNetwork.Parameters.AdamIteration++;
+
+        GC.Collect(GC.MaxGeneration, GCCollectionMode.Forced, true);
+    }
+
+    /// <summary>
+    /// Reinitialize with new graphs.
+    /// </summary>
+    /// <param name="graphs">The graphs.</param>
+    public void Reinitialize(List<GapGraph> graphs)
+    {
+        this.gapGraphs.Clear();
+        this.gapGraphs.AddRange(graphs);
     }

     /// <summary>
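
For orientation, here is a minimal usage sketch of the updated API. The hyperparameter values and the LoadGraphs/LoadBatches helpers are assumptions for illustration only and do not appear in this commit; the constructor parameters, Initialize, Reinitialize, and Reset are taken from the code above. Note that the constructor derives numFeatures as (numIndices * embeddingSize) + 3, so with the assumed values below it would work out to (12 * 16) + 3 = 195.

    List<GapGraph> graphs = LoadGraphs();            // hypothetical helper producing GapGraph instances
    var network = new GraphAttentionPathsNeuralNetwork(
        graphs,
        numIndices: 12,      // assumed value
        alphabetSize: 64,    // assumed value
        embeddingSize: 16,   // assumed value
        numLayers: 3,        // assumed value
        numQueries: 4,       // assumed value
        learningRate: 0.001,
        clipValue: 4.0);

    await network.Initialize();

    foreach (List<GapGraph> batch in LoadBatches())  // hypothetical batching helper
    {
        network.Reinitialize(batch);                 // swap in the next set of graphs
        // ... forward/backward pass and ApplyGradients() go here ...
        await network.Reset();                       // clear the maps, re-initialize sub-networks, advance AdamIteration
    }
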
@@ -136,7 +188,7 @@ public async Task Initialize()
     public void SaveWeights()
     {
         Guid guid = Guid.NewGuid();
-        var dir = $"E:\\store\\{guid}";
+        var dir = $"E:\\store\\{guid}_{this.readoutNeuralNetwork.Parameters.AdamIteration}";
         Directory.CreateDirectory(dir);
         int index = 0;
         foreach (var modelLayer in this.modelLayers)
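
Because the save directory now encodes the Adam iteration as a suffix, a small helper can pick the most recent checkpoint by parsing that suffix. The sketch below is not part of this commit; the only assumptions carried over from the code above are the E:\store root and the {guid}_{iteration} directory naming.

    using System.IO;

    public static class CheckpointLocator
    {
        // Hypothetical helper: returns the directory under root whose name ends with the
        // highest "_{iteration}" suffix, or null when no such directory exists.
        public static string FindLatest(string root = "E:\\store")
        {
            string latest = null;
            int best = -1;
            foreach (var dir in Directory.EnumerateDirectories(root))
            {
                var name = Path.GetFileName(dir);
                var sep = name.LastIndexOf('_');
                if (sep < 0)
                {
                    continue; // directory saved before the iteration suffix was introduced
                }

                if (int.TryParse(name.Substring(sep + 1), out var iteration) && iteration > best)
                {
                    best = iteration;
                    latest = dir;
                }
            }

            return latest;
        }
    }
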
@@ -151,7 +203,7 @@ public void SaveWeights()
     /// </summary>
     public void ApplyWeights()
     {
-        var guid = "b93e5314-3a40-4353-8b44-0795cbfe0d4e";
+        var guid = "78362272-3112-49ab-8e7f-a95bcccc4f1f";
         var dir = $"E:\\store\\{guid}";
         for (int i = 0; i < this.modelLayers.Count; ++i)
         {
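
ApplyWeights still reads its checkpoint GUID from a hard-coded constant, so restoring a different run currently means editing that string. A minimal sketch of the intended save/restore pairing follows; the network variable and the surrounding training code are assumptions, only Initialize, ApplyWeights, and SaveWeights come from this file.

    await network.Initialize();   // build the model layers first
    network.ApplyWeights();       // load each layer's weights from E:\store\{guid} (hard-coded guid)
    // ... continue training or run inference ...
    network.SaveWeights();        // write each layer to E:\store\{newGuid}_{AdamIteration}
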
@@ -162,6 +214,19 @@ public void ApplyWeights()
         }
     }

+    /// <summary>
+    /// Apply the gradients to update the weights.
+    /// </summary>
+    public void ApplyGradients()
+    {
+        var clipper = this.readoutNeuralNetwork.Utilities.GradientClipper;
+        clipper.Clip(this.modelLayers.ToArray());
+        var adamOptimizer = this.readoutNeuralNetwork.Utilities.AdamOptimizer;
+        adamOptimizer.Optimize(this.modelLayers.ToArray());
+        GradientClearer clearer = new GradientClearer();
+        clearer.Clear(this.modelLayers.ToArray());
+    }
+
     /// <summary>
     /// Make a forward pass through the computation graph.
     /// </summary>
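
ApplyGradients reuses the readout network's gradient clipper and Adam optimizer across all model layers and then clears the accumulated gradients. A hedged sketch of where it sits in a single training step follows; only Reinitialize, ApplyGradients, and Reset are taken from this commit, while Forward, ComputeLoss, and Backward are placeholder names for the forward and backward passes that this hunk only hints at.

    // Hypothetical training step; the placeholder calls are assumptions, not confirmed method names.
    network.Reinitialize(batchOfGraphs);        // supply the graphs for this step
    var output = network.Forward();             // assumed forward-pass entry point
    var gradientOfLoss = ComputeLoss(output);   // assumed loss and gradient w.r.t. the output
    await network.Backward(gradientOfLoss);     // assumed backward pass filling the layer gradients
    network.ApplyGradients();                   // clip, Adam-update, then clear gradients
    await network.Reset();                      // re-initialize sub-networks and advance AdamIteration
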