Skip to content

Commit 7c2f248

Browse files
authored
Merge pull request #2 from gyeongin/auto-grad-mode
Implement AutoGradMode for TorchSharp
2 parents 6f2e094 + 4f4e5d9 commit 7c2f248

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

Test/TorchSharp.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,31 @@ public void TestGrad()
337337
}
338338
}
339339

340+
[TestMethod]
341+
public void TestAutoGradMode()
342+
{
343+
var x = FloatTensor.RandomN(new long[] { 2, 3 }, device: "cpu:0", requiresGrad: true);
344+
using (var mode = new AutoGradMode(false))
345+
{
346+
var sum = x.Sum();
347+
sum.Backward();
348+
var grad = x.Grad();
349+
Assert.IsTrue(grad.Handle == IntPtr.Zero);
350+
}
351+
using (var mode = new AutoGradMode(true))
352+
{
353+
var sum = x.Sum();
354+
sum.Backward();
355+
var grad = x.Grad();
356+
Assert.IsFalse(grad.Handle == IntPtr.Zero);
357+
var data = grad.Data;
358+
for (int i = 0; i < 2 * 3; i++)
359+
{
360+
Assert.AreEqual(data[i], 1.0);
361+
}
362+
}
363+
}
364+
340365
[TestMethod]
341366
public void TestSubInPlace()
342367
{

TorchSharp/Torch.cs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Runtime.InteropServices;
1+
using System;
2+
using System.Runtime.InteropServices;
23

34
namespace TorchSharp
45
{
@@ -12,4 +13,35 @@ public static void SetSeed(long seed)
1213
NN_Seed(seed);
1314
}
1415
}
16+
17+
public class AutoGradMode : IDisposable
18+
{
19+
[DllImport("LibTorchSharp")]
20+
extern static bool THS_gradmode_is_enabled();
21+
22+
[DllImport("LibTorchSharp")]
23+
extern static void THS_gradmode_set_enabled(bool enabled);
24+
25+
public AutoGradMode(bool enabled)
26+
{
27+
prev_mode = THS_gradmode_is_enabled();
28+
THS_gradmode_set_enabled(enabled);
29+
}
30+
31+
public void Dispose()
32+
{
33+
Dispose(true);
34+
GC.SuppressFinalize(this);
35+
}
36+
37+
public void Dispose(bool disposing)
38+
{
39+
if (disposing)
40+
{
41+
THS_gradmode_set_enabled(prev_mode);
42+
}
43+
}
44+
45+
bool prev_mode;
46+
}
1547
}

0 commit comments

Comments
 (0)