Skip to content

Commit ad3beef

Browse files
committed
Add transformer.
1 parent d0ecde6 commit ad3beef

File tree

8 files changed

+774
-52
lines changed

8 files changed

+774
-52
lines changed

examples/gnn/ParallelReverseAutoDiff.GnnExample/GraphAttentionPaths/GraphAttentionPathsNeuralNetwork.cs

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ namespace ParallelReverseAutoDiff.Test.GraphAttentionPaths
77
{
88
using System;
99
using System.IO;
10-
using ParallelReverseAutoDiff.GnnExample.Common;
1110
using ParallelReverseAutoDiff.RMAD;
1211
using ParallelReverseAutoDiff.Test.GraphAttentionPaths.AttentionMessagePassing;
1312
using ParallelReverseAutoDiff.Test.GraphAttentionPaths.EdgeAttention;
1413
using ParallelReverseAutoDiff.Test.GraphAttentionPaths.Embedding;
1514
using ParallelReverseAutoDiff.Test.GraphAttentionPaths.GCN;
15+
using ParallelReverseAutoDiff.Test.GraphAttentionPaths.Transformer;
1616

1717
/// <summary>
1818
/// Graph Attention Paths Neural Network.
@@ -22,7 +22,7 @@ public class GraphAttentionPathsNeuralNetwork
2222
private const string WEIGHTSSAVEPATH = "D:\\models\\initialWeights2.json";
2323
private readonly List<EmbeddingNeuralNetwork> embeddingNeuralNetwork;
2424
private readonly List<EdgeAttentionNeuralNetwork> edgeAttentionNeuralNetwork;
25-
private readonly List<LstmNeuralNetwork> lstmNeuralNetwork;
25+
private readonly List<TransformerNeuralNetwork> transformerNeuralNetwork;
2626
private readonly List<AttentionMessagePassingNeuralNetwork> attentionMessagePassingNeuralNetwork;
2727
private readonly GcnNeuralNetwork gcnNeuralNetwork;
2828
private readonly ReadoutNeuralNetwork readoutNeuralNetwork;
@@ -37,7 +37,7 @@ public class GraphAttentionPathsNeuralNetwork
3737
private readonly double learningRate;
3838
private readonly double clipValue;
3939
private readonly Dictionary<int, Guid> typeToIdMap;
40-
private readonly Dictionary<int, Guid> typeToIdMapLstm;
40+
private readonly Dictionary<int, Guid> typeToIdMapTransformer;
4141
private readonly Dictionary<int, Guid> typeToIdMapAttention;
4242
private readonly Dictionary<int, Guid> typeToIdMapEmbeddings;
4343
private readonly Dictionary<GapPath, List<GapPath>> connectedPathsMap;
@@ -72,11 +72,11 @@ public GraphAttentionPathsNeuralNetwork(List<GapGraph> graphs, int batchSize, in
7272
this.embeddingNeuralNetwork = new List<EmbeddingNeuralNetwork>();
7373
this.edgeAttentionNeuralNetwork = new List<EdgeAttentionNeuralNetwork>();
7474
this.typeToIdMap = new Dictionary<int, Guid>();
75-
this.typeToIdMapLstm = new Dictionary<int, Guid>();
75+
this.typeToIdMapTransformer = new Dictionary<int, Guid>();
7676
this.typeToIdMapAttention = new Dictionary<int, Guid>();
7777
this.typeToIdMapEmbeddings = new Dictionary<int, Guid>();
7878
this.connectedPathsMap = new Dictionary<GapPath, List<GapPath>>();
79-
this.lstmNeuralNetwork = new List<LstmNeuralNetwork>();
79+
this.transformerNeuralNetwork = new List<TransformerNeuralNetwork>();
8080
this.attentionMessagePassingNeuralNetwork = new List<AttentionMessagePassingNeuralNetwork>();
8181
this.gcnNeuralNetwork = new GcnNeuralNetwork(numLayers, 4, this.numFeatures, learningRate, clipValue);
8282
this.readoutNeuralNetwork = new ReadoutNeuralNetwork(numLayers, numQueries, 4, this.numFeatures, learningRate, clipValue);
@@ -106,10 +106,10 @@ public async Task Initialize()
106106

107107
for (int i = 0; i < 7; ++i)
108108
{
109-
var model = new LstmNeuralNetwork(this.numFeatures * (int)Math.Pow(2d, (double)this.numLayers), 500, this.numFeatures * (int)Math.Pow(2d, (double)this.numLayers) * 2, i + 2, this.numLayers, this.learningRate, this.clipValue);
110-
this.lstmNeuralNetwork.Add(model);
111-
await this.lstmNeuralNetwork[i].Initialize();
112-
this.modelLayers = this.modelLayers.Concat(this.lstmNeuralNetwork[i].ModelLayers).ToList();
109+
var model = new TransformerNeuralNetwork(this.numLayers, this.numQueries / 2, i + 2, this.numFeatures * (int)Math.Pow(2d, (double)this.numLayers), i + 2, this.learningRate, this.clipValue);
110+
this.transformerNeuralNetwork.Add(model);
111+
await this.transformerNeuralNetwork[i].Initialize();
112+
this.modelLayers = this.modelLayers.Concat(this.transformerNeuralNetwork[i].ModelLayers).ToList();
113113
}
114114

115115
for (int i = 0; i < 7; ++i)
@@ -135,27 +135,31 @@ public async Task Initialize()
135135
/// </summary>
136136
public void SaveWeights()
137137
{
138-
var weightStore = new WeightStore();
139-
weightStore.AddRange(this.modelLayers);
140-
weightStore.Save(new FileInfo(WEIGHTSSAVEPATH));
138+
Guid guid = Guid.NewGuid();
139+
var dir = $"E:\\store\\{guid}";
140+
Directory.CreateDirectory(dir);
141+
int index = 0;
142+
foreach (var modelLayer in this.modelLayers)
143+
{
144+
modelLayer.SaveWeightsAndMoments(new FileInfo($"{dir}\\layer{index}.json"));
145+
index++;
146+
}
141147
}
142148

143149
/// <summary>
144150
/// Apply the weights from the save path.
145151
/// </summary>
146152
public void ApplyWeights()
147153
{
148-
var weightStore = new WeightStore();
149-
weightStore.Load(new FileInfo(WEIGHTSSAVEPATH));
154+
var guid = "b93e5314-3a40-4353-8b44-0795cbfe0d4e";
155+
var dir = $"E:\\store\\{guid}";
150156
for (int i = 0; i < this.modelLayers.Count; ++i)
151157
{
152158
var modelLayer = this.modelLayers[i];
153-
var weights = weightStore.ToModelLayerWeights(i);
154-
modelLayer.ApplyWeights(weights);
159+
var file = new FileInfo($"{dir}\\layer{i}.json");
160+
modelLayer.LoadWeightsAndMoments(file);
161+
GC.Collect(GC.MaxGeneration, GCCollectionMode.Forced, true);
155162
}
156-
157-
weightStore = null;
158-
GC.Collect(GC.MaxGeneration, GCCollectionMode.Forced);
159163
}
160164

161165
/// <summary>
@@ -279,26 +283,26 @@ public DeepMatrix Forward()
279283
}
280284
}
281285

282-
Dictionary<int, List<DeepMatrix>> inputsByLength = new Dictionary<int, List<DeepMatrix>>();
286+
Dictionary<int, List<Matrix>> inputsByLength = new Dictionary<int, List<Matrix>>();
283287
Dictionary<(int Length, int Index), GapPath> pathIndexMap = new Dictionary<(int Length, int Index), GapPath>();
284288

285289
foreach (var graph in this.gapGraphs)
286290
{
287291
foreach (var path in graph.GapPaths)
288292
{
289293
var pathLength = path.Nodes.Count;
290-
var input = new DeepMatrix(pathLength, this.numFeatures * (int)Math.Pow(2d, (double)this.numLayers), 1);
291-
for (int i = 0; i < input.Depth; ++i)
294+
var input = new Matrix(pathLength, this.numFeatures * (int)Math.Pow(2d, (double)this.numLayers));
295+
for (int i = 0; i < input.Rows; ++i)
292296
{
293-
for (int j = 0; j < input.Rows; ++j)
297+
for (int j = 0; j < input.Cols; ++j)
294298
{
295-
input[i][j][0] = path.Nodes[i].FeatureVector[j][0];
299+
input[i][j] = path.Nodes[i].FeatureVector[j][0];
296300
}
297301
}
298302

299303
if (!inputsByLength.ContainsKey(pathLength))
300304
{
301-
inputsByLength[pathLength] = new List<DeepMatrix>();
305+
inputsByLength[pathLength] = new List<Matrix>();
302306
}
303307

304308
inputsByLength[pathLength].Add(input);
@@ -310,15 +314,14 @@ public DeepMatrix Forward()
310314
foreach (var length in inputsByLength.Keys)
311315
{
312316
var batchedInput = inputsByLength[length].ToArray(); // Array of Matrix where each Matrix is one path's full sequence (rows = timesteps, cols = features)
313-
var switched = CommonMatrixUtils.SwitchFirstTwoDimensions(batchedInput);
314-
var lstmNet = this.lstmNeuralNetwork[length - 2]; // Because a path must have a length of at least two
315-
lstmNet.Parameters.BatchSize = batchedInput.Length;
316-
lstmNet.InitializeState();
317-
lstmNet.AutomaticForwardPropagate(new FourDimensionalMatrix(switched));
317+
var transformerNet = this.transformerNeuralNetwork[length - 2]; // Because a path must have a length of at least two
318+
transformerNet.Parameters.BatchSize = batchedInput.Length;
319+
transformerNet.InitializeState();
320+
transformerNet.AutomaticForwardPropagate(new DeepMatrix(batchedInput));
318321
var id = Guid.NewGuid();
319-
this.typeToIdMapLstm.Add(length, id);
320-
lstmNet.StoreOperationIntermediates(id);
321-
var output = lstmNet.OutputPathFeatures[length - 1];
322+
this.typeToIdMapTransformer.Add(length, id);
323+
transformerNet.StoreOperationIntermediates(id);
324+
var output = transformerNet.Output;
322325
for (int i = 0; i < output.Depth; ++i)
323326
{
324327
var path = pathIndexMap[(length, i)];
@@ -522,7 +525,7 @@ public async Task Backward(DeepMatrix gradientOfLossWrtReadoutOutput)
522525
this.ApplyGradients(pathToGradientsMap);
523526

524527
Dictionary<int, List<Matrix>> pathLengthToGradientMap = new Dictionary<int, List<Matrix>>();
525-
Dictionary<(int, int), GapPath> indexesToPathMapLstm = new Dictionary<(int, int), GapPath>();
528+
Dictionary<(int, int), GapPath> indexesToPathMapTransformer = new Dictionary<(int, int), GapPath>();
526529
foreach (var graph in this.gapGraphs)
527530
{
528531
foreach (var path in graph.GapPaths)
@@ -531,12 +534,12 @@ public async Task Backward(DeepMatrix gradientOfLossWrtReadoutOutput)
531534
var gradient = pathToGradientsMap[path].Item1;
532535
if (pathLengthToGradientMap.ContainsKey(pathLength))
533536
{
534-
indexesToPathMapLstm.Add((pathLength, pathLengthToGradientMap[pathLength].Count), path);
537+
indexesToPathMapTransformer.Add((pathLength, pathLengthToGradientMap[pathLength].Count), path);
535538
pathLengthToGradientMap[pathLength].Add(gradient);
536539
}
537540
else
538541
{
539-
indexesToPathMapLstm.Add((pathLength, 0), path);
542+
indexesToPathMapTransformer.Add((pathLength, 0), path);
540543
pathLengthToGradientMap.Add(pathLength, new List<Matrix> { gradient });
541544
}
542545
}
@@ -545,24 +548,24 @@ public async Task Backward(DeepMatrix gradientOfLossWrtReadoutOutput)
545548
Dictionary<GapNode, Matrix> nodeToGradientMap = new Dictionary<GapNode, Matrix>();
546549
foreach (var key in pathLengthToGradientMap.Keys)
547550
{
548-
var lstmNet = this.lstmNeuralNetwork[key - 2];
549-
lstmNet.RestoreOperationIntermediates(this.typeToIdMapLstm[key]);
550-
var lstmGradient = CommonMatrixUtils.SwitchFirstTwoDimensions((await lstmNet.AutomaticBackwardPropagate(new DeepMatrix(pathLengthToGradientMap[key].ToArray()))).ToArray());
551+
var transformerNet = this.transformerNeuralNetwork[key - 2];
552+
transformerNet.RestoreOperationIntermediates(this.typeToIdMapTransformer[key]);
553+
var transformerGradient = await transformerNet.AutomaticBackwardPropagate(new DeepMatrix(pathLengthToGradientMap[key].ToArray()));
551554

552-
for (int i = 0; i < lstmGradient.Length; ++i)
555+
for (int i = 0; i < transformerGradient.Depth; ++i)
553556
{
554-
var path = indexesToPathMapLstm[(key, i)];
557+
var path = indexesToPathMapTransformer[(key, i)];
555558
var nodeCount = path.Nodes.Count;
556559
for (int j = 0; j < nodeCount; ++j)
557560
{
558561
var node = path.Nodes[j];
559562
if (!nodeToGradientMap.ContainsKey(node))
560563
{
561-
nodeToGradientMap.Add(node, lstmGradient[i][j]);
564+
nodeToGradientMap.Add(node, new Matrix(transformerGradient[i][j]).Transpose());
562565
}
563566
else
564567
{
565-
nodeToGradientMap[node].Accumulate(lstmGradient[i][j].ToArray());
568+
nodeToGradientMap[node].Accumulate(new Matrix(transformerGradient[i][j]).Transpose().ToArray());
566569
}
567570
}
568571
}

0 commit comments

Comments
 (0)