Skip to content

Commit f35fa43

Browse files
committed
Update networks.
1 parent c145aa9 commit f35fa43

File tree

2 files changed

6 additions and 11 deletions (+6 −11 lines changed)

2 files changed

6 additions and 11 deletions (+6 −11 lines changed)

test/ParallelReverseAutoDiff.Test/GraphAttentionPaths/Embedding/EmbeddingNeuralNetwork.cs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,7 @@ public void AutomaticForwardPropagate(DeepMatrix input)
165165
/// </summary>
166166
/// <param name="gradient">The gradient of the loss.</param>
167167
/// <returns>The gradient.</returns>
168-
public async Task<DeepMatrix> AutomaticBackwardPropagate(DeepMatrix gradient)
168+
public async Task<Matrix> AutomaticBackwardPropagate(DeepMatrix gradient)
169169
{
170170
IOperationBase? backwardStartOperation = null;
171171
backwardStartOperation = this.computationGraph["vector_concatenate_trans_0_0"];
@@ -179,7 +179,7 @@ public async Task<DeepMatrix> AutomaticBackwardPropagate(DeepMatrix gradient)
179179
}
180180

181181
IOperationBase? backwardEndOperation = this.computationGraph["batch_embeddings_0_0"];
182-
return backwardEndOperation.CalculatedGradient[0] as DeepMatrix ?? throw new InvalidOperationException("Calculated gradient should not be null.");
182+
return backwardEndOperation.CalculatedGradient[1] as Matrix ?? throw new InvalidOperationException("Calculated gradient should not be null.");
183183
}
184184

185185
/// <summary>

test/ParallelReverseAutoDiff.Test/GraphAttentionPaths/Transformer/TransformerNeuralNetwork.cs

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@ public TransformerNeuralNetwork(int numLayers, int numQueries, int numPaths, int
6060
.AddModelElementGroup("VB", new[] { 1, numInputOutputFeatures }, InitializationType.Zeroes);
6161
var inputLayer = inputLayerBuilder.Build();
6262
this.inputLayers.Add(inputLayer);
63+
numInputFeatures = numInputOutputFeatures;
6364
}
6465

6566
this.nestedLayers = new List<IModelLayer>();
@@ -73,6 +74,7 @@ public TransformerNeuralNetwork(int numLayers, int numQueries, int numPaths, int
7374
.AddModelElementGroup("QB", new[] { numQueries, 1, numNestedOutputFeatures }, InitializationType.Zeroes);
7475
var nestedLayer = nestedLayerBuilder.Build();
7576
this.nestedLayers.Add(nestedLayer);
77+
numNestedFeatures = numNestedOutputFeatures;
7678
outputFeaturesList.Add(numNestedOutputFeatures * numQueries);
7779
}
7880

@@ -85,18 +87,11 @@ public TransformerNeuralNetwork(int numLayers, int numQueries, int numPaths, int
8587
.AddModelElementGroup("F2W", new[] { outputFeaturesList[i], (outputFeaturesList[i] / 2) }, InitializationType.Xavier)
8688
.AddModelElementGroup("F2B", new[] { 1, (outputFeaturesList[i] / 2) }, InitializationType.Xavier)
8789
.AddModelElementGroup("Beta", new[] { 1, 1 }, InitializationType.He);
88-
if (i < (this.NumLayers - 1))
89-
{
90-
outputLayerBuilder
91-
.AddModelElementGroup("R", new[] { (outputFeaturesList[i] / 2), this.NumFeatures }, InitializationType.Xavier)
92-
.AddModelElementGroup("RB", new[] { 1, this.NumFeatures }, InitializationType.Zeroes);
93-
}
94-
else
95-
{
90+
9691
outputLayerBuilder
9792
.AddModelElementGroup("R", new[] { (outputFeaturesList[i] / 2), this.NumFeatures * 2 }, InitializationType.Xavier)
9893
.AddModelElementGroup("RB", new[] { 1, this.NumFeatures * 2 }, InitializationType.Zeroes);
99-
}
94+
10095
var outputLayer = outputLayerBuilder.Build();
10196
this.outputLayers.Add(outputLayer);
10297
}

0 commit comments

Comments
 (0)