Skip to content

Commit 01e520c

Browse files
committed
Update trainer.
1 parent 890ef0a commit 01e520c

File tree

4 files changed

+59
-50
lines changed

4 files changed

+59
-50
lines changed

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/TiledNet.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,18 +147,18 @@ public void ApplyGradients()
147147
/// Make a forward pass through the computation graph.
148148
/// </summary>
149149
/// <returns>The gradient of the loss wrt the output.</returns>
150-
public (Matrix, Matrix, Matrix) Forward(Matrix input, Matrix rotationTargets, double targetAngle)
150+
public (Matrix, Matrix, Matrix) Forward(Matrix input, double[,] percentages)
151151
{
152152

153153
var gatNet = this.TiledNetwork;
154-
gatNet.TargetAngle = targetAngle;
154+
//gatNet.TargetAngle = targetAngle;
155155
gatNet.InitializeState();
156-
gatNet.RotationTargets.Replace(rotationTargets.ToArray());
156+
//gatNet.RotationTargets.Replace(rotationTargets.ToArray());
157157
gatNet.AutomaticForwardPropagate(input);
158158
var output = gatNet.Output;
159159

160160
SquaredArclengthEuclideanLossOperation arclengthLoss = SquaredArclengthEuclideanLossOperation.Instantiate(gatNet);
161-
var loss = arclengthLoss.Forward(output, targetAngle);
161+
var loss = arclengthLoss.Forward(output, Math.PI / 4);
162162
var gradient = arclengthLoss.Backward();
163163

164164
return (gradient, output, loss);

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/TiledNetTrainer.cs

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ await files.WithRepeatAsync(async (pngFile, token) =>
3535
List<List<double>> data = new List<List<double>>();
3636
int count = 0;
3737
int totalCount = 0;
38+
double[,] values = new double[512, 512];
3839
for (int i = 0; i < 512; ++i)
3940
{
4041
var subvect = new List<double>();
@@ -45,6 +46,7 @@ await files.WithRepeatAsync(async (pngFile, token) =>
4546
{
4647
count++;
4748
}
49+
values[i, j] = value;
4850
subvect.Add(value);
4951
totalCount++;
5052
}
@@ -56,20 +58,22 @@ await files.WithRepeatAsync(async (pngFile, token) =>
5658
int uIndex = file.IndexOf("_");
5759
var prefix = file.Substring(0, uIndex);
5860

59-
var glyphFile = pngFile.Replace("\\" + prefix, "\\" + prefix + "_glyph").Replace("svg\\", "svg-glyph\\");
60-
Node[,] glyphNodes = ImageSerializer.DeserializeImageWithoutAntiAlias(glyphFile);
61-
Matrix rotationTargets = new Matrix(15, 15);
62-
Vector3[] glyphs = new Vector3[225];
63-
int m = 0;
64-
for (int k = 0; k < 15; ++k)
65-
{
66-
for (int l = 0; l < 15; ++l)
67-
{
68-
rotationTargets[k, l] = glyphNodes[k, l].IsForeground ? 1d : 0d;
69-
glyphs[m] = new Vector3(0f, 0f, (float)rotationTargets[k, l]);
70-
m++;
71-
}
72-
}
61+
var percentages = CalculatePercentagesAboveThreshold(values);
62+
63+
//var glyphFile = pngFile.Replace("\\" + prefix, "\\" + prefix + "_glyph").Replace("svg\\", "svg-glyph\\");
64+
//Node[,] glyphNodes = ImageSerializer.DeserializeImageWithoutAntiAlias(glyphFile);
65+
//Matrix rotationTargets = new Matrix(15, 15);
66+
//Vector3[] glyphs = new Vector3[225];
67+
//int m = 0;
68+
//for (int k = 0; k < 15; ++k)
69+
//{
70+
// for (int l = 0; l < 15; ++l)
71+
// {
72+
// rotationTargets[k, l] = glyphNodes[k, l].IsForeground ? 1d : 0d;
73+
// glyphs[m] = new Vector3(0f, 0f, (float)rotationTargets[k, l]);
74+
// m++;
75+
// }
76+
//}
7377

7478
Matrix matrix = new Matrix(data.Count, data[0].Count);
7579
for (int j = 0; j < data.Count; j++)
@@ -82,7 +86,7 @@ await files.WithRepeatAsync(async (pngFile, token) =>
8286

8387
i++;
8488

85-
var res = net.Forward(matrix, rotationTargets, perc > 35d ? 3 * Math.PI / 4d : Math.PI / 4d);
89+
var res = net.Forward(matrix, percentages);
8690
var gradient = res.Item1;
8791
var output = res.Item2;
8892
var loss = res.Item3;
@@ -131,5 +135,37 @@ await files.WithRepeatAsync(async (pngFile, token) =>
131135
CudaBlas.Instance.Dispose();
132136
}
133137
}
138+
139+
private double[,] CalculatePercentagesAboveThreshold(double[,] input, double threshold = 0.5)
140+
{
141+
int inputWidth = input.GetLength(0); // 512
142+
int inputHeight = input.GetLength(1); // 512
143+
int sectionSize = inputWidth / 8; // Assuming the input is always 512x512 and output grid is 8x8
144+
145+
double[,] percentages = new double[8, 8];
146+
147+
for (int sectionX = 0; sectionX < 8; sectionX++)
148+
{
149+
for (int sectionY = 0; sectionY < 8; sectionY++)
150+
{
151+
int aboveThresholdCount = 0;
152+
for (int x = sectionX * sectionSize; x < (sectionX + 1) * sectionSize; x++)
153+
{
154+
for (int y = sectionY * sectionSize; y < (sectionY + 1) * sectionSize; y++)
155+
{
156+
if (input[x, y] > threshold)
157+
{
158+
aboveThresholdCount++;
159+
}
160+
}
161+
}
162+
163+
double totalValues = sectionSize * sectionSize;
164+
percentages[sectionX, sectionY] = (double)aboveThresholdCount / totalValues;
165+
}
166+
}
167+
168+
return percentages;
169+
}
134170
}
135171
}

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/TiledNetwork/TiledNetwork.cs

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -360,37 +360,7 @@ public async Task<Matrix> AutomaticBackwardPropagate(Matrix gradient, Matrix gra
360360
var output = this.computationGraph["output_0_0"];
361361
Matrix gg = gradient;
362362

363-
if (gradient0 != null)
364-
{
365-
var outputResult = (output as ElementwiseVectorCartesianRotationAndSumOperation).Backward(gradient).Results[0] as Matrix;
366-
367-
var targetedSum0 = this.computationGraph["targeted_sum_0_0_0"];
368-
var targetedSum1 = this.computationGraph["targeted_sum_1_0_0"];
369-
370-
var glyph = this.computationGraph["glyph_0_0"] as ElementwiseVectorCartesianTiledOperation;
371-
372-
var targetedResult0 = (targetedSum0 as ElementwiseVectorCartesianTargetedSumOperation).Backward(gradient0).Results[0] as Matrix;
373-
var targetedResult1 = (targetedSum1 as ElementwiseVectorCartesianTargetedSumOperation).Backward(gradient1).Results[0] as Matrix;
374-
375-
gg = new Matrix(CommonMatrixUtils.InitializeZeroMatrix(225, 2).ToArray());
376-
for (int i = 0; i < 225; i++)
377-
{
378-
gg[i, 0] = outputResult[i, 0] + targetedResult0[i, 0] + targetedResult1[i, 0];
379-
gg[i, 1] = outputResult[i, 1] + targetedResult0[i, 1] + targetedResult1[i, 1];
380-
}
381-
382-
for (int i = 0; i < 225; i++)
383-
{
384-
gg[i, 0] = targetedResult1[i, 0];
385-
gg[i, 1] = targetedResult1[i, 1];
386-
}
387-
388-
backwardStartOperation = glyph;
389-
}
390-
else
391-
{
392-
backwardStartOperation = output;
393-
}
363+
backwardStartOperation = output;
394364

395365
if (!CommonMatrixUtils.IsAllZeroes(gradient))
396366
{

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/VectorNetwork/RMAD/ElementwiseVectorConstituentTiledMultiplyOperation.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ public Matrix Forward(Matrix input1, Matrix input2, Matrix weights)
5050
this.input2 = brokenInput2;
5151
this.weights = brokenWeights;
5252
this.calculatedValues = new CalculatedValues[brokenInput1.GetLength(0), brokenInput2.GetLength(1)][,];
53+
this.output = new Matrix[brokenInput1.GetLength(0), brokenInput2.GetLength(1)];
54+
this.sumX = new Matrix[brokenInput1.GetLength(0), brokenInput2.GetLength(1)];
55+
this.sumY = new Matrix[brokenInput1.GetLength(0), brokenInput2.GetLength(1)];
5356

5457
Parallel.For(0, brokenInput1.GetLength(0), i =>
5558
{

0 commit comments

Comments
 (0)