Skip to content

Commit 7e9ff48

Browse files
committed
Update tiled net.
1 parent edb3a33 commit 7e9ff48

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/TiledNet.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public void AdjustLearningRate(double learningRate)
8181
/// <returns>The task.</returns>
8282
public async Task Initialize()
8383
{
84-
var initialAdamIteration = 1;
84+
var initialAdamIteration = 462;
8585
var model = new TiledNetwork.TiledNetwork(this.numLayers, this.numNodes, this.numFeatures, this.learningRate, this.clipValue, "tilednet");
8686
model.Parameters.AdamIteration = initialAdamIteration;
8787
this.TiledNetwork = model;
@@ -122,7 +122,7 @@ public void SaveWeights()
122122
/// </summary>
123123
public void ApplyWeights()
124124
{
125-
var guid = "tiled_05842a36-7e13-4ba8-a981-df0c5a2b50c5_11";
125+
var guid = "tiled_0a7edd07-8034-4d91-b7fd-68fb84a22a48_462";
126126
var dir = $"E:\\vnnstore\\{guid}";
127127
for (int i = 0; i < this.modelLayers.Count; ++i)
128128
{

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/TiledNetTrainer.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public async Task Train()
1414
CudaBlas.Instance.Initialize();
1515
TiledNet net = new TiledNet(512, 6144, 3, 0.01d, 4d);
1616
await net.Initialize();
17-
//net.ApplyWeights();
17+
net.ApplyWeights();
1818

1919
var pngFiles = Directory.GetFiles(@"E:\images\inputs\svg", "*.png");
2020

@@ -106,18 +106,20 @@ await files.WithRepeatAsync(async (pngFile, token) =>
106106
//}
107107

108108

109-
//Console.WriteLine($"Iteration {i} Output X: {output[0, 0]}, Output Y: {output[0, 1]}, Grad: {gradient[0, 0]}, {gradient[0, 1]}");
110-
//Console.WriteLine($"Loss: {loss[0, 0]}, Perc: {perc}");
109+
Console.WriteLine($"Iteration {i}");
110+
Console.WriteLine($"Loss: {lossAndGradient[0, 0].Item1[0, 0]}, {lossAndGradient[0, 1].Item1[0, 0]}, {lossAndGradient[0, 2].Item1[0, 0]}");
111+
Console.WriteLine($"Loss: {lossAndGradient[1, 0].Item1[0, 0]}, {lossAndGradient[1, 1].Item1[0, 0]}, {lossAndGradient[1, 2].Item1[0, 0]}");
112+
Console.WriteLine($"Loss: {lossAndGradient[2, 0].Item1[0, 0]}, {lossAndGradient[2, 1].Item1[0, 0]}, {lossAndGradient[2, 2].Item1[0, 0]}");
111113
//Console.WriteLine($"O1 X: {o1[0, 0]}, O1 Y: {o1[0, 1]}, Loss: {loss[0, 0]}, {loss0[0, 0]}, {loss1[0, 0]}");
112-
await net.Backward(lossAndGradient);
113-
net.ApplyGradients();
114+
//await net.Backward(lossAndGradient);
115+
//net.ApplyGradients();
114116
//}
115117

116118
await net.Reset();
117119
Thread.Sleep(1000);
118120
if (i % 11 == 10)
119121
{
120-
net.SaveWeights();
122+
//net.SaveWeights();
121123
}
122124

123125
//if (token.UsageCount == 0)

0 commit comments

Comments
 (0)