Skip to content

Commit baa5d5d

Browse files
committed
Add training dynamics.
1 parent f555681 commit baa5d5d

File tree

5 files changed

+68
-4
lines changed

5 files changed

+68
-4
lines changed

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/GlyphNet.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,12 @@ public void ApplyGradients()
179179

180180
Console.WriteLine($"Max Mag 0: {maxMag0}, Max Mag 1: {maxMag1}");
181181

182-
SquaredArclengthEuclideanMagnitudeLossOperation arclengthLoss0 = SquaredArclengthEuclideanMagnitudeLossOperation.Instantiate(gatNet);
183-
var loss0 = arclengthLoss0.Forward(targetedSum0, (3 * Math.PI) / 4d, maxMag0);
182+
SquaredArclengthEuclideanLossOperation arclengthLoss0 = SquaredArclengthEuclideanLossOperation.Instantiate(gatNet);
183+
var loss0 = arclengthLoss0.Forward(targetedSum0, (3 * Math.PI) / 4d);
184184
var gradient0 = arclengthLoss0.Backward();
185185

186-
SquaredArclengthEuclideanMagnitudeLossOperation arclengthLoss1 = SquaredArclengthEuclideanMagnitudeLossOperation.Instantiate(gatNet);
187-
var loss1 = arclengthLoss1.Forward(targetedSum1, (1 * Math.PI) / 4d, maxMag1);
186+
SquaredArclengthEuclideanLossOperation arclengthLoss1 = SquaredArclengthEuclideanLossOperation.Instantiate(gatNet);
187+
var loss1 = arclengthLoss1.Forward(targetedSum1, (1 * Math.PI) / 4d);
188188
var gradient1 = arclengthLoss1.Backward();
189189

190190
SquaredArclengthEuclideanMagnitudeLossOperation arclengthLoss = SquaredArclengthEuclideanMagnitudeLossOperation.Instantiate(gatNet);

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/GlyphNetTrainer.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ await files.WithRepeatAsync(async (pngFile, token) =>
107107
net.SaveWeights();
108108
}
109109

110+
token.Repeat();
110111
//if (token.UsageCount == 0)
111112
//{
112113
// token.Repeat(2);
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
namespace ParallelReverseAutoDiff.GravNetExample.GlyphNetwork
2+
{
3+
using ParallelReverseAutoDiff.RMAD;
4+
5+
public class GlyphTrainingDynamics
6+
{
7+
private static readonly Lazy<GlyphTrainingDynamics> lazy = new Lazy<GlyphTrainingDynamics>(() => new GlyphTrainingDynamics(), true);
8+
9+
public static GlyphTrainingDynamics Instance { get { return lazy.Value; } }
10+
11+
public Matrix[] PreviousTargetedSum { get; set; } = new Matrix[2];
12+
13+
public Matrix[] LastTargetedSum { get; set; } = new Matrix[2];
14+
15+
public double PreviousAngleLoss { get; set; }
16+
17+
public double PreviousEuclideanLoss { get; set; }
18+
19+
public double PreviousMagnitudeLoss { get; set; }
20+
21+
public double AngleLoss { get; set; }
22+
23+
public double EuclideanLoss { get; set; }
24+
25+
public double MagnitudeLoss { get; set; }
26+
27+
public double GradAngleLossX { get; set; }
28+
29+
public double GradAngleLossY { get; set; }
30+
31+
public double GradEuclideanLossX { get; set; }
32+
33+
public double GradEuclideanLossY { get; set; }
34+
35+
public double GradMagnitudeLossX { get; set; }
36+
37+
public double GradMagnitudeLossY { get; set; }
38+
39+
private GlyphTrainingDynamics()
40+
{
41+
42+
}
43+
}
44+
}

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/VectorNetwork/RMAD/ElementwiseVectorCartesianTargetedSumOperation.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
//------------------------------------------------------------------------------
66
namespace ParallelReverseAutoDiff.RMAD
77
{
8+
using ParallelReverseAutoDiff.GravNetExample.GlyphNetwork;
9+
810
/// <summary>
911
/// Element-wise cartesian sum operation.
1012
/// </summary>
@@ -33,6 +35,8 @@ public static IOperation Instantiate(NeuralNetwork net)
3335
/// <returns>The output of the element-wise vector rotation operation.</returns>
3436
public Matrix Forward(Matrix inputVectors, Matrix rotationTargets, int target)
3537
{
38+
GlyphTrainingDynamics.Instance.PreviousTargetedSum[target] = GlyphTrainingDynamics.Instance.LastTargetedSum[target];
39+
GlyphTrainingDynamics.Instance.LastTargetedSum[target] = inputVectors;
3640
this.inputVectors = inputVectors;
3741
this.rotationTargets = rotationTargets;
3842
this.target = target;

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/VectorNetwork/RMAD/SquaredArclengthEuclideanMagnitudeLossOperation.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
//------------------------------------------------------------------------------
66
namespace ParallelReverseAutoDiff.RMAD
77
{
8+
using ParallelReverseAutoDiff.GravNetExample.GlyphNetwork;
89
using System;
910

1011
/// <summary>
@@ -90,6 +91,13 @@ public Matrix Forward(Matrix predictions, double targetAngle, double maxMagnitud
9091

9192
double arcLength = Math.Pow(radius * theta, 2);
9293

94+
GlyphTrainingDynamics.Instance.PreviousAngleLoss = GlyphTrainingDynamics.Instance.AngleLoss;
95+
GlyphTrainingDynamics.Instance.PreviousEuclideanLoss = GlyphTrainingDynamics.Instance.EuclideanLoss;
96+
GlyphTrainingDynamics.Instance.PreviousMagnitudeLoss = GlyphTrainingDynamics.Instance.MagnitudeLoss;
97+
GlyphTrainingDynamics.Instance.AngleLoss = arcLength;
98+
GlyphTrainingDynamics.Instance.EuclideanLoss = distanceAccum;
99+
GlyphTrainingDynamics.Instance.MagnitudeLoss = magnitudeDiscrepancy;
100+
93101
// Compute the squared magnitude of the loss
94102
double lossMagnitude = (arcLength + distanceAccum + magnitudeDiscrepancy) / 3d;
95103

@@ -121,6 +129,13 @@ public Matrix Backward()
121129
dPredictions[0, 0] = (cX * gradX) + eX + dMagDiscrepancy_dX;
122130
dPredictions[0, 1] = (cY * gradY) + eY + dMagDiscrepancy_dY;
123131

132+
GlyphTrainingDynamics.Instance.GradAngleLossX = cX * gradX;
133+
GlyphTrainingDynamics.Instance.GradAngleLossY = cY * gradY;
134+
GlyphTrainingDynamics.Instance.GradEuclideanLossX = eX;
135+
GlyphTrainingDynamics.Instance.GradEuclideanLossY = eY;
136+
GlyphTrainingDynamics.Instance.GradMagnitudeLossX = dMagDiscrepancy_dX;
137+
GlyphTrainingDynamics.Instance.GradMagnitudeLossY = dMagDiscrepancy_dY;
138+
124139
if (double.IsNaN(dPredictions[0, 0]) || double.IsNaN(dPredictions[0, 1]))
125140
{
126141

0 commit comments

Comments
 (0)