Skip to content

Commit a476f73

Browse files
committed
Update dynamics.
1 parent 2f0ff8a commit a476f73

File tree

6 files changed

+153
-3
lines changed

6 files changed

+153
-3
lines changed

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/GlyphNet.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,11 @@ public void ApplyGradients()
180180
Console.WriteLine($"Max Mag 0: {maxMag0}, Max Mag 1: {maxMag1}");
181181

182182
SquaredArclengthEuclideanLossOperation arclengthLoss0 = SquaredArclengthEuclideanLossOperation.Instantiate(gatNet);
183-
var loss0 = arclengthLoss0.Forward(targetedSum0, (3 * Math.PI) / 4d);
183+
var loss0 = arclengthLoss0.Forward(targetedSum0, (3 * Math.PI) / 4d, 0);
184184
var gradient0 = arclengthLoss0.Backward();
185185

186186
SquaredArclengthEuclideanLossOperation arclengthLoss1 = SquaredArclengthEuclideanLossOperation.Instantiate(gatNet);
187-
var loss1 = arclengthLoss1.Forward(targetedSum1, (1 * Math.PI) / 4d);
187+
var loss1 = arclengthLoss1.Forward(targetedSum1, (1 * Math.PI) / 4d, 1);
188188
var gradient1 = arclengthLoss1.Backward();
189189

190190
SquaredArclengthEuclideanMagnitudeLossOperation arclengthLoss = SquaredArclengthEuclideanMagnitudeLossOperation.Instantiate(gatNet);

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/GlyphNetwork/GlyphTrainingDynamics.cs

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,137 @@ public class GlyphTrainingDynamics
4040

4141
public double GradMagnitudeLossY { get; set; }
4242

43+
public double ActualAngle { get; set; }
44+
45+
public double ActualAngleTarget0 { get; set; }
46+
47+
public double ActualAngleTarget1 { get; set; }
48+
49+
public double TargetAngle { get; set; }
50+
51+
public double TargetAngle0 { get; set; }
52+
53+
public double TargetAngle1 { get; set; }
54+
4355
private GlyphTrainingDynamics()
4456
{
4557

4658
}
59+
60+
public (double, double) CalculateGradientDirection(double x, double y, int targetIndex = -1)
61+
{
62+
var actualAngle = Math.Atan2(y, x);
63+
64+
var targetAngle = TargetAngle;
65+
if (targetIndex == 0)
66+
{
67+
targetAngle = TargetAngle0;
68+
} else if (targetIndex == 1)
69+
{
70+
targetAngle = TargetAngle1;
71+
}
72+
73+
var quadrant = GetQuadrant(actualAngle);
74+
var targetQuadrant = GetQuadrant(targetAngle);
75+
var oppositeAngle = CalculateOppositeAngle(targetAngle);
76+
if (targetQuadrant == 1)
77+
{
78+
if (quadrant == 1)
79+
{
80+
if (actualAngle < targetAngle)
81+
{
82+
return (1, -1); // x increase, y decrease
83+
}
84+
else
85+
{
86+
return (-1, 1); // x decrease, y increase
87+
}
88+
}
89+
else if (quadrant == 2)
90+
{
91+
return (-1, -1); // x decrease, y decrease
92+
}
93+
else if (quadrant == 3)
94+
{
95+
if (actualAngle < oppositeAngle)
96+
{
97+
return (1, -1); // x increase, y decrease
98+
}
99+
else
100+
{
101+
return (-1, 1); // x decrease, y increase
102+
}
103+
}
104+
else
105+
{
106+
return (-1, -1); // x decrease, y decrease
107+
}
108+
} else if (targetQuadrant == 2)
109+
{
110+
if (quadrant == 1)
111+
{
112+
return (1, -1); // x increase, y decrease
113+
}
114+
else if (quadrant == 2)
115+
{
116+
if (actualAngle < targetAngle)
117+
{
118+
return (1, 1); // x increase, y increase
119+
}
120+
else
121+
{
122+
return (-1, -1); // x decrease, y decrease
123+
}
124+
}
125+
else if (quadrant == 3)
126+
{
127+
return (1, -1); // x increase, y decrease
128+
}
129+
else
130+
{
131+
if (actualAngle < oppositeAngle)
132+
{
133+
return (1, 1); // x increase, y increase
134+
}
135+
else
136+
{
137+
return (1, -1); // x decrease, y decrease
138+
}
139+
}
140+
}
141+
142+
throw new InvalidOperationException("Unsupported target quadrant");
143+
}
144+
145+
private int GetQuadrant(double angleInRadians)
146+
{
147+
// Normalize the angle to be within 0 to 2pi radians
148+
double normalizedAngle = angleInRadians % (2 * Math.PI);
149+
// Adjust if negative to ensure it falls within the 0 to 2pi range
150+
if (normalizedAngle < 0)
151+
normalizedAngle += 2 * Math.PI;
152+
153+
// Determine the quadrant
154+
if (normalizedAngle >= 0 && normalizedAngle < Math.PI / 2)
155+
return 1;
156+
else if (normalizedAngle >= Math.PI / 2 && normalizedAngle < Math.PI)
157+
return 2;
158+
else if (normalizedAngle >= Math.PI && normalizedAngle < 3 * Math.PI / 2)
159+
return 3;
160+
else // normalizedAngle >= 3 * Math.PI / 2 && normalizedAngle < 2 * Math.PI
161+
return 4;
162+
}
163+
164+
private double CalculateOppositeAngle(double targetAngle)
165+
{
166+
// Add π to the target angle to find the opposite angle
167+
double oppositeAngle = targetAngle + Math.PI;
168+
169+
// Normalize the opposite angle to be within [-2π, 2π]
170+
if (oppositeAngle > 2 * Math.PI) oppositeAngle -= 2 * Math.PI;
171+
else if (oppositeAngle < -2 * Math.PI) oppositeAngle += 2 * Math.PI;
172+
173+
return oppositeAngle;
174+
}
47175
}
48176
}

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/VectorNetwork/RMAD/ElementwiseVectorCartesianRotationAndSumOperation.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ public override BackwardResult Backward(Matrix dOutput)
104104
dInputVectors[vectorIndex, 0] = (dOutput[0, 0] * cosTheta) + (-dOutput[0, 1] * sinTheta);
105105
dInputVectors[vectorIndex, 1] = (dOutput[0, 0] * sinTheta) + (dOutput[0, 1] * cosTheta);
106106

107+
var gradientDirection = GlyphTrainingDynamics.Instance.CalculateGradientDirection(inputVectors[vectorIndex, 0], inputVectors[vectorIndex, 1], (int) rotationTargets[i, j]);
108+
dInputVectors[vectorIndex, 0] = Math.Abs(dInputVectors[vectorIndex, 0]) * gradientDirection.Item1;
109+
dInputVectors[vectorIndex, 1] = Math.Abs(dInputVectors[vectorIndex, 1]) * gradientDirection.Item2;
110+
107111
vectorIndex++;
108112
}
109113
}

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/VectorNetwork/RMAD/ElementwiseVectorCartesianTargetedSumOperation.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ public override BackwardResult Backward(Matrix dOutput)
7979
{
8080
// Initialize dInputVectors with the same shape as the forward input vectors
8181
Matrix dInputVectors = new Matrix(225, 2);
82+
var inputVectors = this.inputVectors;
8283

8384
int vectorIndex = 0;
8485
for (int i = 0; i < 15; i++)
@@ -89,6 +90,10 @@ public override BackwardResult Backward(Matrix dOutput)
8990
{
9091
dInputVectors[vectorIndex, 0] = dOutput[0, 0];
9192
dInputVectors[vectorIndex, 1] = dOutput[0, 1];
93+
94+
var gradientDirection = GlyphTrainingDynamics.Instance.CalculateGradientDirection(inputVectors[vectorIndex, 0], inputVectors[vectorIndex, 1], this.target);
95+
dInputVectors[vectorIndex, 0] = Math.Abs(dInputVectors[vectorIndex, 0]) * gradientDirection.Item1;
96+
dInputVectors[vectorIndex, 1] = Math.Abs(dInputVectors[vectorIndex, 1]) * gradientDirection.Item2;
9297
}
9398

9499
vectorIndex++;

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/VectorNetwork/RMAD/SquaredArclengthEuclideanLossOperation.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
//------------------------------------------------------------------------------
66
namespace ParallelReverseAutoDiff.RMAD
77
{
8+
using ParallelReverseAutoDiff.GravNetExample.GlyphNetwork;
89
using System;
910

1011
/// <summary>
@@ -39,8 +40,9 @@ public static SquaredArclengthEuclideanLossOperation Instantiate(NeuralNetwork n
3940
/// </summary>
4041
/// <param name="predictions">The predictions matrix.</param>
4142
/// <param name="targetAngle">The target angle.</param>
43+
/// <param name="targetIndex">The target index.</param>
4244
/// <returns>The scalar loss value.</returns>
43-
public Matrix Forward(Matrix predictions, double targetAngle)
45+
public Matrix Forward(Matrix predictions, double targetAngle, int targetIndex = -1)
4446
{
4547
this.targetAngle = targetAngle;
4648
var xOutput = predictions[0, 0];
@@ -54,6 +56,15 @@ public Matrix Forward(Matrix predictions, double targetAngle)
5456

5557
double magnitude = Math.Sqrt(xOutput * xOutput + yOutput * yOutput);
5658
this.actualAngle = Math.Atan2(yOutput, xOutput);
59+
if (targetIndex == 0)
60+
{
61+
GlyphTrainingDynamics.Instance.ActualAngleTarget0 = this.actualAngle;
62+
GlyphTrainingDynamics.Instance.TargetAngle0 = targetAngle;
63+
} else if (targetIndex == 1)
64+
{
65+
GlyphTrainingDynamics.Instance.ActualAngleTarget1 = this.actualAngle;
66+
GlyphTrainingDynamics.Instance.TargetAngle1 = targetAngle;
67+
}
5768

5869
var xTarget = Math.Cos(targetAngle) * magnitude;
5970
this.xTarget = xTarget;

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/VectorNetwork/RMAD/SquaredArclengthEuclideanMagnitudeLossOperation.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ public Matrix Forward(Matrix predictions, double targetAngle, double maxMagnitud
5454

5555
double magnitude = Math.Sqrt(xOutput * xOutput + yOutput * yOutput);
5656
this.actualAngle = Math.Atan2(yOutput, xOutput);
57+
GlyphTrainingDynamics.Instance.ActualAngle = this.actualAngle;
58+
GlyphTrainingDynamics.Instance.TargetAngle = targetAngle;
5759

5860
var xTarget = Math.Cos(targetAngle) * magnitude;
5961
this.xTarget = xTarget;

0 commit comments

Comments
 (0)