
Commit 77c4d9f

Update training set loader.
1 parent 773f653 commit 77c4d9f

3 files changed: +126 additions, -28 deletions

examples/gnn/ParallelReverseAutoDiff.GnnExample/GraphAttentionPaths/GraphAttentionPathsNeuralNetwork.cs

Lines changed: 75 additions & 10 deletions
@@ -33,7 +33,6 @@ public class GraphAttentionPathsNeuralNetwork
 private readonly int numQueries;
 private readonly int alphabetSize;
 private readonly int embeddingSize;
-private readonly int batchSize;
 private readonly double learningRate;
 private readonly double clipValue;
 private readonly Dictionary<int, Guid> typeToIdMap;
@@ -48,15 +47,14 @@ public class GraphAttentionPathsNeuralNetwork
 /// Initializes a new instance of the <see cref="GraphAttentionPathsNeuralNetwork"/> class.
 /// </summary>
 /// <param name="graphs">The graphs.</param>
-/// <param name="batchSize">The batch size.</param>
 /// <param name="numIndices">The number of indices.</param>
 /// <param name="alphabetSize">The alphabet size.</param>
 /// <param name="embeddingSize">The embedding size.</param>
 /// <param name="numLayers">The number of layers.</param>
 /// <param name="numQueries">The number of queries.</param>
 /// <param name="learningRate">The learning rate.</param>
 /// <param name="clipValue">The clip Value.</param>
-public GraphAttentionPathsNeuralNetwork(List<GapGraph> graphs, int batchSize, int numIndices, int alphabetSize, int embeddingSize, int numLayers, int numQueries, double learningRate, double clipValue)
+public GraphAttentionPathsNeuralNetwork(List<GapGraph> graphs, int numIndices, int alphabetSize, int embeddingSize, int numLayers, int numQueries, double learningRate, double clipValue)
 {
     this.gapGraphs = graphs;
     this.numFeatures = (numIndices * embeddingSize) + 3;
@@ -65,21 +63,75 @@ public GraphAttentionPathsNeuralNetwork(List<GapGraph> graphs, int batchSize, in
     this.numIndices = numIndices;
     this.numLayers = numLayers;
     this.numQueries = numQueries;
-    this.batchSize = batchSize;
     this.learningRate = learningRate;
     this.clipValue = clipValue;
     this.modelLayers = new List<IModelLayer>();
     this.embeddingNeuralNetwork = new List<EmbeddingNeuralNetwork>();
     this.edgeAttentionNeuralNetwork = new List<EdgeAttentionNeuralNetwork>();
+    this.transformerNeuralNetwork = new List<TransformerNeuralNetwork>();
+    this.attentionMessagePassingNeuralNetwork = new List<AttentionMessagePassingNeuralNetwork>();
+    this.gcnNeuralNetwork = new GcnNeuralNetwork(numLayers, 4, this.numFeatures, learningRate, clipValue);
+    this.readoutNeuralNetwork = new ReadoutNeuralNetwork(numLayers, numQueries, 4, this.numFeatures, learningRate, clipValue);
     this.typeToIdMap = new Dictionary<int, Guid>();
     this.typeToIdMapTransformer = new Dictionary<int, Guid>();
     this.typeToIdMapAttention = new Dictionary<int, Guid>();
     this.typeToIdMapEmbeddings = new Dictionary<int, Guid>();
     this.connectedPathsMap = new Dictionary<GapPath, List<GapPath>>();
-    this.transformerNeuralNetwork = new List<TransformerNeuralNetwork>();
-    this.attentionMessagePassingNeuralNetwork = new List<AttentionMessagePassingNeuralNetwork>();
-    this.gcnNeuralNetwork = new GcnNeuralNetwork(numLayers, 4, this.numFeatures, learningRate, clipValue);
-    this.readoutNeuralNetwork = new ReadoutNeuralNetwork(numLayers, numQueries, 4, this.numFeatures, learningRate, clipValue);
+}
+
+/// <summary>
+/// Reset the network.
+/// </summary>
+/// <returns>A task.</returns>
+public async Task Reset()
+{
+    this.typeToIdMap.Clear();
+    this.typeToIdMapTransformer.Clear();
+    this.typeToIdMapAttention.Clear();
+    this.typeToIdMapEmbeddings.Clear();
+    this.connectedPathsMap.Clear();
+
+    for (int i = 0; i < 7; ++i)
+    {
+        await this.embeddingNeuralNetwork[i].Initialize();
+        this.embeddingNeuralNetwork[i].Parameters.AdamIteration++;
+    }
+
+    for (int i = 0; i < 7; ++i)
+    {
+        await this.edgeAttentionNeuralNetwork[i].Initialize();
+        this.edgeAttentionNeuralNetwork[i].Parameters.AdamIteration++;
+    }
+
+    for (int i = 0; i < 7; ++i)
+    {
+        await this.transformerNeuralNetwork[i].Initialize();
+        this.transformerNeuralNetwork[i].Parameters.AdamIteration++;
+    }
+
+    for (int i = 0; i < 7; ++i)
+    {
+        await this.attentionMessagePassingNeuralNetwork[i].Initialize();
+        this.attentionMessagePassingNeuralNetwork[i].Parameters.AdamIteration++;
+    }
+
+    await this.gcnNeuralNetwork.Initialize();
+    this.gcnNeuralNetwork.Parameters.AdamIteration++;
+
+    await this.readoutNeuralNetwork.Initialize();
+    this.readoutNeuralNetwork.Parameters.AdamIteration++;
+
+    GC.Collect(GC.MaxGeneration, GCCollectionMode.Forced, true);
+}
+
+/// <summary>
+/// Reinitialize with new graphs.
+/// </summary>
+/// <param name="graphs">The graphs.</param>
+public void Reinitialize(List<GapGraph> graphs)
+{
+    this.gapGraphs.Clear();
+    this.gapGraphs.AddRange(graphs);
 }

 /// <summary>
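
Note: the new Reset method clears the per-graph lookup maps, re-initializes each sub-network, and increments its Parameters.AdamIteration, so the optimizer's step count keeps advancing across mini-batches even though per-batch state is rebuilt. A minimal self-contained sketch of that pattern follows; SubNetwork, AdamIteration, and Initialize here are hypothetical stand-ins for illustration, not the ParallelReverseAutoDiff API.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Hypothetical stand-in for one sub-network (embedding, edge attention, and so on).
public class SubNetwork
{
    public int AdamIteration { get; set; }

    // Rebuilds per-batch buffers; Adam moments and the step count live outside this call.
    public Task Initialize() => Task.CompletedTask;
}

public class CompositeNetwork
{
    private readonly List<SubNetwork> subNetworks = new List<SubNetwork>();
    private readonly Dictionary<int, Guid> typeToIdMap = new Dictionary<int, Guid>();

    public CompositeNetwork(int count)
    {
        for (int i = 0; i < count; ++i)
        {
            this.subNetworks.Add(new SubNetwork());
        }
    }

    // Mirrors the Reset pattern: clear per-batch lookups, re-initialize every
    // sub-network, and bump its Adam iteration so the next update is step t + 1.
    public async Task Reset()
    {
        this.typeToIdMap.Clear();
        foreach (var sub in this.subNetworks)
        {
            await sub.Initialize();
            sub.AdamIteration++;
        }
    }
}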
@@ -136,7 +188,7 @@ public async Task Initialize()
 public void SaveWeights()
 {
     Guid guid = Guid.NewGuid();
-    var dir = $"E:\\store\\{guid}";
+    var dir = $"E:\\store\\{guid}_{this.readoutNeuralNetwork.Parameters.AdamIteration}";
     Directory.CreateDirectory(dir);
     int index = 0;
     foreach (var modelLayer in this.modelLayers)
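
Note: SaveWeights now appends the readout network's AdamIteration to the checkpoint directory name, so each saved folder records which training step produced it. The per-layer write loop is collapsed in this view; the sketch below shows one way such an iteration-stamped save could look, with layerBlobs and the file naming chosen purely for illustration.

using System;
using System.IO;

public static class CheckpointWriter
{
    // Writes one blob per model layer into a "<guid>_<adamIteration>" directory,
    // mirroring the directory-naming scheme introduced in SaveWeights.
    public static string Save(string root, int adamIteration, byte[][] layerBlobs)
    {
        var dir = Path.Combine(root, $"{Guid.NewGuid()}_{adamIteration}");
        Directory.CreateDirectory(dir);

        int index = 0;
        foreach (var blob in layerBlobs)
        {
            File.WriteAllBytes(Path.Combine(dir, $"layer{index}.bin"), blob);
            index++;
        }

        return dir;
    }
}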
@@ -151,7 +203,7 @@ public void SaveWeights()
 /// </summary>
 public void ApplyWeights()
 {
-    var guid = "b93e5314-3a40-4353-8b44-0795cbfe0d4e";
+    var guid = "78362272-3112-49ab-8e7f-a95bcccc4f1f";
     var dir = $"E:\\store\\{guid}";
     for (int i = 0; i < this.modelLayers.Count; ++i)
     {
@@ -162,6 +214,19 @@ public void ApplyWeights()
     }
 }

+/// <summary>
+/// Apply the gradients to update the weights.
+/// </summary>
+public void ApplyGradients()
+{
+    var clipper = this.readoutNeuralNetwork.Utilities.GradientClipper;
+    clipper.Clip(this.modelLayers.ToArray());
+    var adamOptimizer = this.readoutNeuralNetwork.Utilities.AdamOptimizer;
+    adamOptimizer.Optimize(this.modelLayers.ToArray());
+    GradientClearer clearer = new GradientClearer();
+    clearer.Clear(this.modelLayers.ToArray());
+}
+
 /// <summary>
 /// Make a forward pass through the computation graph.
 /// </summary>
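
Note: ApplyGradients runs the shared model layers through clip, then Adam update, then gradient clear, so weights are always updated from clipped gradients and the next backward pass starts from zeroed accumulators. The GradientClipper and AdamOptimizer internals are not part of this diff; the sketch below illustrates the same clip-optimize-clear ordering on plain double arrays with a global-norm clip, as an assumption rather than the library's implementation.

using System;
using System.Linq;

public class SimpleAdam
{
    private readonly double lr;
    private readonly double beta1 = 0.9, beta2 = 0.999, eps = 1e-8;
    private readonly double[] m, v;
    private int t;

    public SimpleAdam(int size, double lr)
    {
        this.lr = lr;
        this.m = new double[size];
        this.v = new double[size];
    }

    // Clip by global L2 norm, apply one Adam step, then zero the gradients,
    // mirroring the clip -> optimize -> clear order of ApplyGradients.
    public void Step(double[] weights, double[] gradients, double clipValue)
    {
        double norm = Math.Sqrt(gradients.Sum(g => g * g));
        if (norm > clipValue)
        {
            double scale = clipValue / norm;
            for (int i = 0; i < gradients.Length; ++i) { gradients[i] *= scale; }
        }

        this.t++;
        for (int i = 0; i < weights.Length; ++i)
        {
            this.m[i] = (this.beta1 * this.m[i]) + ((1 - this.beta1) * gradients[i]);
            this.v[i] = (this.beta2 * this.v[i]) + ((1 - this.beta2) * gradients[i] * gradients[i]);
            double mHat = this.m[i] / (1 - Math.Pow(this.beta1, this.t));
            double vHat = this.v[i] / (1 - Math.Pow(this.beta2, this.t));
            weights[i] -= this.lr * mHat / (Math.Sqrt(vHat) + this.eps);
        }

        Array.Clear(gradients, 0, gradients.Length);
    }
}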

examples/gnn/ParallelReverseAutoDiff.GnnExample/ParallelReverseAutoDiff.GnnExample.xml

Lines changed: 12 additions & 2 deletions
Some generated files are not rendered by default.

examples/gnn/ParallelReverseAutoDiff.GnnExample/TrainingSetLoader.cs

Lines changed: 39 additions & 16 deletions
@@ -16,6 +16,7 @@ namespace ParallelReverseAutoDiff.GnnExample
 public class TrainingSetLoader
 {
     private Random rand;
+    private GraphAttentionPathsNeuralNetwork neuralNetwork;

     /// <summary>
     /// Initializes a new instance of the <see cref="TrainingSetLoader"/> class.
@@ -32,29 +33,51 @@ public TrainingSetLoader()
 public async Task LoadMiniBatch()
 {
     var graphFiles = Directory.GetFiles("G:\\My Drive\\graphs", "*.zip").ToList();
-    var randomGraphFiles = graphFiles.OrderBy(x => this.rand.Next()).Take(4).ToList();
-    List<GapGraph> graphs = new List<GapGraph>();
-    for (int i = 0; i < randomGraphFiles.Count; ++i)
+
+    for (int i = 0; i < 10; ++i)
     {
-        var file = randomGraphFiles[i];
-        var jsons = this.ExtractFromZip(file);
-        var randomJson = jsons.OrderBy(x => this.rand.Next()).First();
-        var graph = JsonConvert.DeserializeObject<GapGraph>(randomJson) ?? throw new InvalidOperationException("Could not deserialize to graph.");
-        graph.Populate();
-        graphs.Add(graph);
+        var randomGraphFiles = graphFiles.OrderBy(x => this.rand.Next()).Take(4).ToList();
+        List<GapGraph> graphs = new List<GapGraph>();
+        for (int j = 0; j < randomGraphFiles.Count; ++j)
+        {
+            var file = randomGraphFiles[j];
+            var jsons = this.ExtractFromZip(file);
+            var randomJson = jsons.OrderBy(x => this.rand.Next()).First();
+            var graph = JsonConvert.DeserializeObject<GapGraph>(randomJson) ?? throw new InvalidOperationException("Could not deserialize to graph.");
+            graph.Populate();
+            graphs.Add(graph);
+        }
+
+        var json = JsonConvert.SerializeObject(graphs);
+        File.WriteAllText("minibatch.json", json);
+
+        await this.ProcessMiniBatch(graphs);
+        Thread.Sleep(5000);
     }

-    var json = JsonConvert.SerializeObject(graphs);
-    File.WriteAllText("minibatch.json", json);
+    this.neuralNetwork.SaveWeights();
+}

-    int batchSize = 4;
+private async Task ProcessMiniBatch(List<GapGraph> graphs)
+{
     try
     {
         CudaBlas.Instance.Initialize();
-        GraphAttentionPathsNeuralNetwork neuralNetwork = new GraphAttentionPathsNeuralNetwork(graphs, batchSize, 16, 115, 5, 2, 4, 0.001d, 4d);
-        await neuralNetwork.Initialize();
-        DeepMatrix gradientOfLoss = neuralNetwork.Forward();
-        await neuralNetwork.Backward(gradientOfLoss);
+        if (this.neuralNetwork == null)
+        {
+            this.neuralNetwork = new GraphAttentionPathsNeuralNetwork(graphs, 16, 115, 7, 2, 4, 0.001d, 4d);
+            await this.neuralNetwork.Initialize();
+            this.neuralNetwork.ApplyWeights();
+        }
+        else
+        {
+            this.neuralNetwork.Reinitialize(graphs);
+        }
+
+        DeepMatrix gradientOfLoss = this.neuralNetwork.Forward();
+        await this.neuralNetwork.Backward(gradientOfLoss);
+        this.neuralNetwork.ApplyGradients();
+        await this.neuralNetwork.Reset();
     }
     finally
     {
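
Note: LoadMiniBatch now runs ten mini-batches per call; each one samples four zip archives at random and takes one randomly chosen JSON graph from each before handing the batch to ProcessMiniBatch. ExtractFromZip is not shown in this hunk, so the sketch below is an assumed reading of it built on System.IO.Compression: collect every JSON entry in an archive, then pick one at random.

using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Linq;

public static class GraphSampler
{
    private static readonly Random Rand = new Random();

    // Reads every .json entry in the archive into a string.
    public static List<string> ExtractFromZip(string zipPath)
    {
        var jsons = new List<string>();
        using (var archive = ZipFile.OpenRead(zipPath))
        {
            foreach (var entry in archive.Entries.Where(e => e.Name.EndsWith(".json", StringComparison.OrdinalIgnoreCase)))
            {
                using (var reader = new StreamReader(entry.Open()))
                {
                    jsons.Add(reader.ReadToEnd());
                }
            }
        }

        return jsons;
    }

    // Picks one serialized graph at random, as the loader does for each archive.
    public static string SampleOne(string zipPath)
    {
        var jsons = ExtractFromZip(zipPath);
        return jsons.OrderBy(x => Rand.Next()).First();
    }
}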

0 commit comments