Commit 5aeb4c0

Update adam optimizer.
1 parent 77c4d9f commit 5aeb4c0

3 files changed, +41 -6 lines changed

src/RMAD/AdamOptimizer.cs

Lines changed: 26 additions & 5 deletions
@@ -6,7 +6,7 @@
 namespace ParallelReverseAutoDiff.RMAD
 {
     using System;
-    using System.Collections.Generic;
+    using System.Diagnostics;
     using System.Threading.Tasks;
 
     /// <summary>
@@ -93,26 +93,47 @@ public void Optimize(IModelLayer[] layers)
         private void UpdateWeightWithAdam(Matrix w, Matrix mW, Matrix vW, Matrix gradient, double beta1, double beta2, double epsilon)
         {
             // Update biased first moment estimate
-            mW = MatrixUtils.MatrixAdd(MatrixUtils.ScalarMultiply(beta1, mW), MatrixUtils.ScalarMultiply(1 - beta1, gradient));
+            var firstMoment = MatrixUtils.MatrixAdd(MatrixUtils.ScalarMultiply(beta1, mW), MatrixUtils.ScalarMultiply(1 - beta1, gradient));
 
             // Update biased second raw moment estimate
-            vW = MatrixUtils.MatrixAdd(MatrixUtils.ScalarMultiply(beta2, vW), MatrixUtils.ScalarMultiply(1 - beta2, MatrixUtils.HadamardProduct(gradient, gradient)));
+            var secondMoment = MatrixUtils.MatrixAdd(MatrixUtils.ScalarMultiply(beta2, vW), MatrixUtils.ScalarMultiply(1 - beta2, MatrixUtils.HadamardProduct(gradient, gradient)));
 
             // Compute bias-corrected first moment estimate
-            Matrix mW_hat = MatrixUtils.ScalarMultiply(1 / (1 - Math.Pow(beta1, this.network.Parameters.AdamIteration)), mW);
+            Matrix mW_hat = MatrixUtils.ScalarMultiply(1 / (1 - Math.Pow(beta1, this.network.Parameters.AdamIteration)), firstMoment);
 
             // Compute bias-corrected second raw moment estimate
-            Matrix vW_hat = MatrixUtils.ScalarMultiply(1 / (1 - Math.Pow(beta2, this.network.Parameters.AdamIteration)), vW);
+            Matrix vW_hat = MatrixUtils.ScalarMultiply(1 / (1 - Math.Pow(beta2, this.network.Parameters.AdamIteration)), secondMoment);
 
             // Update weights
             for (int i = 0; i < w.Length; i++)
             {
                 for (int j = 0; j < w[0].Length; j++)
                 {
                     double weightReductionValue = this.network.Parameters.LearningRate * mW_hat[i][j] / (Math.Sqrt(vW_hat[i][j]) + epsilon);
+#if DEBUG
+                    Debug.WriteLine(weightReductionValue + " vs gradient: " + gradient[i][j]);
+#endif
                     w[i][j] -= weightReductionValue;
                 }
             }
+
+            // Update first moment
+            for (int i = 0; i < mW.Length; i++)
+            {
+                for (int j = 0; j < mW[0].Length; j++)
+                {
+                    mW[i][j] = firstMoment[i][j];
+                }
+            }
+
+            // Update second moment
+            for (int i = 0; i < vW.Length; i++)
+            {
+                for (int j = 0; j < vW[0].Length; j++)
+                {
+                    vW[i][j] = secondMoment[i][j];
+                }
+            }
         }
     }
 }
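
The change above fixes a genuine bug rather than just refactoring: C# passes object references by value, so the old assignments `mW = ...` and `vW = ...` only rebound the local parameters, and the caller's moment matrices were never updated between iterations. The rewrite computes the new moments into locals, uses those for the bias-corrected estimates, and then copies them back element-wise so the update persists. For reference, the method implements the standard Adam rule (stated here for clarity; it is not spelled out in the commit), with gradient $g_t$, iteration $t$ (`AdamIteration`), and learning rate $\alpha$:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t \odot g_t,$$

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}, \qquad w \leftarrow w - \frac{\alpha\, \hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}.$$

A minimal standalone sketch of the reference-semantics pitfall (hypothetical demo types, not code from this repository):

```csharp
using System;

public static class ReferenceSemanticsDemo
{
    // Rebinding the parameter: the caller's array is untouched.
    private static void Reassign(double[] m)
    {
        m = new[] { 1.0 };
    }

    // Writing through the reference: the caller sees the update.
    private static void WriteBack(double[] m)
    {
        m[0] = 1.0;
    }

    public static void Main()
    {
        var moments = new[] { 0.0 };
        Reassign(moments);
        Console.WriteLine(moments[0]); // prints 0: the reassignment was lost
        WriteBack(moments);
        Console.WriteLine(moments[0]); // prints 1: the element-wise write persisted
    }
}
```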

test/ParallelReverseAutoDiff.Test/GraphAttentionPaths/GraphAttentionPathsNeuralNetwork.cs

Lines changed: 13 additions & 0 deletions
@@ -164,6 +164,19 @@ public void ApplyWeights()
             }
         }
 
+        /// <summary>
+        /// Apply the gradients to update the weights.
+        /// </summary>
+        public void ApplyGradients()
+        {
+            var clipper = this.readoutNeuralNetwork.Utilities.GradientClipper;
+            clipper.Clip(this.modelLayers.ToArray());
+            var adamOptimizer = this.readoutNeuralNetwork.Utilities.AdamOptimizer;
+            adamOptimizer.Optimize(this.modelLayers.ToArray());
+            GradientClearer clearer = new GradientClearer();
+            clearer.Clear(this.modelLayers.ToArray());
+        }
+
         /// <summary>
         /// Make a forward pass through the computation graph.
         /// </summary>
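
With `ApplyGradients` in place, one optimization step is: clip the gradients, apply the Adam update, then clear the accumulated gradients. A sketch of how a full training iteration might look with this API (constructor arguments copied from the test below; the epoch loop and `numEpochs` are hypothetical):

```csharp
// Sketch only: assumes `graphs` and `batchSize` are prepared as in the test below.
var neuralNetwork = new GraphAttentionPathsNeuralNetwork(graphs, batchSize, 16, 115, 3, 2, 4, 0.001d, 4d);
await neuralNetwork.Initialize();

for (int epoch = 0; epoch < numEpochs; epoch++)
{
    DeepMatrix gradientOfLoss = neuralNetwork.Forward();  // forward pass; returns the gradient of the loss
    await neuralNetwork.Backward(gradientOfLoss);         // reverse-mode pass through the computation graph
    neuralNetwork.ApplyGradients();                       // clip, Adam step, clear
}
```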

test/ParallelReverseAutoDiff.Test/GraphAttentionPathsNeuralNetworkTest.cs

Lines changed: 2 additions & 1 deletion
@@ -24,10 +24,11 @@ public async Task GivenGraphAttentionPathsNeuralNetworkMiniBatch_ProcessesMiniBa
 
                 int batchSize = 4;
 
-                GraphAttentionPathsNeuralNetwork neuralNetwork = new GraphAttentionPathsNeuralNetwork(graphs, batchSize, 16, 115, 5, 2, 4, 0.001d, 4d);
+                GraphAttentionPathsNeuralNetwork neuralNetwork = new GraphAttentionPathsNeuralNetwork(graphs, batchSize, 16, 115, 3, 2, 4, 0.001d, 4d);
                 await neuralNetwork.Initialize();
                 DeepMatrix gradientOfLoss = neuralNetwork.Forward();
                 await neuralNetwork.Backward(gradientOfLoss);
+                neuralNetwork.ApplyGradients();
             }
             finally
             {
