|
6 | 6 | namespace ParallelReverseAutoDiff.RMAD |
7 | 7 | { |
8 | 8 | using System; |
9 | | - using System.Collections.Generic; |
| 9 | + using System.Diagnostics; |
10 | 10 | using System.Threading.Tasks; |
11 | 11 |
|
12 | 12 | /// <summary> |
@@ -93,26 +93,47 @@ public void Optimize(IModelLayer[] layers) |
93 | 93 | private void UpdateWeightWithAdam(Matrix w, Matrix mW, Matrix vW, Matrix gradient, double beta1, double beta2, double epsilon) |
94 | 94 | { |
95 | 95 | // Update biased first moment estimate |
96 | | - mW = MatrixUtils.MatrixAdd(MatrixUtils.ScalarMultiply(beta1, mW), MatrixUtils.ScalarMultiply(1 - beta1, gradient)); |
| 96 | + var firstMoment = MatrixUtils.MatrixAdd(MatrixUtils.ScalarMultiply(beta1, mW), MatrixUtils.ScalarMultiply(1 - beta1, gradient)); |
97 | 97 |
|
98 | 98 | // Update biased second raw moment estimate |
99 | | - vW = MatrixUtils.MatrixAdd(MatrixUtils.ScalarMultiply(beta2, vW), MatrixUtils.ScalarMultiply(1 - beta2, MatrixUtils.HadamardProduct(gradient, gradient))); |
| 99 | + var secondMoment = MatrixUtils.MatrixAdd(MatrixUtils.ScalarMultiply(beta2, vW), MatrixUtils.ScalarMultiply(1 - beta2, MatrixUtils.HadamardProduct(gradient, gradient))); |
100 | 100 |
|
101 | 101 | // Compute bias-corrected first moment estimate |
102 | | - Matrix mW_hat = MatrixUtils.ScalarMultiply(1 / (1 - Math.Pow(beta1, this.network.Parameters.AdamIteration)), mW); |
| 102 | + Matrix mW_hat = MatrixUtils.ScalarMultiply(1 / (1 - Math.Pow(beta1, this.network.Parameters.AdamIteration)), firstMoment); |
103 | 103 |
|
104 | 104 | // Compute bias-corrected second raw moment estimate |
105 | | - Matrix vW_hat = MatrixUtils.ScalarMultiply(1 / (1 - Math.Pow(beta2, this.network.Parameters.AdamIteration)), vW); |
| 105 | + Matrix vW_hat = MatrixUtils.ScalarMultiply(1 / (1 - Math.Pow(beta2, this.network.Parameters.AdamIteration)), secondMoment); |
106 | 106 |
|
107 | 107 | // Update weights |
108 | 108 | for (int i = 0; i < w.Length; i++) |
109 | 109 | { |
110 | 110 | for (int j = 0; j < w[0].Length; j++) |
111 | 111 | { |
112 | 112 | double weightReductionValue = this.network.Parameters.LearningRate * mW_hat[i][j] / (Math.Sqrt(vW_hat[i][j]) + epsilon); |
| 113 | +#if DEBUG |
| 114 | + Debug.WriteLine(weightReductionValue + " vs gradient: " + gradient[i][j]); |
| 115 | +#endif |
113 | 116 | w[i][j] -= weightReductionValue; |
114 | 117 | } |
115 | 118 | } |
| 119 | + |
| 120 | + // Update first moment |
| 121 | + for (int i = 0; i < mW.Length; i++) |
| 122 | + { |
| 123 | + for (int j = 0; j < mW[0].Length; j++) |
| 124 | + { |
| 125 | + mW[i][j] = firstMoment[i][j]; |
| 126 | + } |
| 127 | + } |
| 128 | + |
| 129 | + // Update second moment |
| 130 | + for (int i = 0; i < vW.Length; i++) |
| 131 | + { |
| 132 | + for (int j = 0; j < vW[0].Length; j++) |
| 133 | + { |
| 134 | + vW[i][j] = secondMoment[i][j]; |
| 135 | + } |
| 136 | + } |
116 | 137 | } |
117 | 138 | } |
118 | 139 | } |
0 commit comments