Skip to content
This repository was archived by the owner on Aug 11, 2025. It is now read-only.

Commit 7947e7f

Browse files
Dorokhovmigueldeicaza
authored andcommitted
Gradient support (#159)
These changes and the fix in native API resolves issue #25 Now, add gradients works correctly The test will fail until a new version of TensorFlowSharp with updated native libraries is released.
1 parent 240d592 commit 7947e7f

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

TensorFlowSharp/Tensorflow.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,10 +1003,10 @@ public TFOutput [] AddGradients (TFOutput [] y, TFOutput [] x, TFOutput [] dx =
10031003
fixed (TFOutput* py = &y [0]) {
10041004
fixed (TFOutput* px = &x [0]) {
10051005
if (dx == null) {
1006-
TF_AddGradients (handle, py, y.Length, px, x.Length, (TFOutput*)null, status.Handle, pret);
1006+
TF_AddGradients (handle, py, y.Length, px, x.Length, (TFOutput*)null, cstatus.Handle, pret);
10071007
} else {
10081008
fixed (TFOutput* pdx = &dx [0]) {
1009-
TF_AddGradients (handle, py, y.Length, px, x.Length, pdx, status.Handle, pret);
1009+
TF_AddGradients (handle, py, y.Length, px, x.Length, pdx, cstatus.Handle, pret);
10101010
}
10111011
}
10121012
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using TensorFlow;
2+
using Xunit;
3+
4+
namespace TensorFlowSharp.Tests.CSharp
5+
{
6+
public class GradientTests
7+
{
8+
[Fact]
9+
public void ShouldAddGradients ()
10+
{
11+
using (var graph = new TFGraph ())
12+
using (var session = new TFSession (graph)) {
13+
var x = graph.Const (3.0);
14+
15+
var y1 = graph.Square (x, "Square1");
16+
var y2 = graph.Square (y1, "Square2");
17+
18+
var y3 = graph.Square (y2, "Square3");
19+
var g = graph.AddGradients (new TFOutput [] { y1, y3 }, new [] { x });
20+
21+
var r = session.Run (new TFOutput [] { }, new TFTensor [] { }, g);
22+
var dy = (double)r [0].GetValue ();
23+
Assert.Equal (17502.0, dy);
24+
}
25+
}
26+
}
27+
}

tests/TensorFlowSharp.Tests.CSharp/TensorFlowSharp.Tests.CSharp.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
</Reference>
6464
</ItemGroup>
6565
<ItemGroup>
66+
<Compile Include="GradientTests.cs" />
6667
<Compile Include="ArrayTests.cs" />
6768
<Compile Include="TensorTests.cs" />
6869
<Compile Include="ClipTests.cs" />

0 commit comments

Comments
 (0)