Skip to content

Commit 9454a60

Browse files
committed
Merge branch 'LibTorchSharpFirstTest' of https://github.com/interesaaat/TorchSharp into LibTorchSharpFirstTest
2 parents fcf8bc1 + 7c2f248 commit 9454a60

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

Test/TorchSharp.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,31 @@ public void TestGrad()
351351
}
352352
}
353353

354+
[TestMethod]
355+
public void TestAutoGradMode()
356+
{
357+
var x = FloatTensor.RandomN(new long[] { 2, 3 }, device: "cpu:0", requiresGrad: true);
358+
using (var mode = new AutoGradMode(false))
359+
{
360+
var sum = x.Sum();
361+
sum.Backward();
362+
var grad = x.Grad();
363+
Assert.IsTrue(grad.Handle == IntPtr.Zero);
364+
}
365+
using (var mode = new AutoGradMode(true))
366+
{
367+
var sum = x.Sum();
368+
sum.Backward();
369+
var grad = x.Grad();
370+
Assert.IsFalse(grad.Handle == IntPtr.Zero);
371+
var data = grad.Data;
372+
for (int i = 0; i < 2 * 3; i++)
373+
{
374+
Assert.AreEqual(data[i], 1.0);
375+
}
376+
}
377+
}
378+
354379
[TestMethod]
355380
public void TestSubInPlace()
356381
{

TorchSharp/Torch.cs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Runtime.InteropServices;
1+
using System;
2+
using System.Runtime.InteropServices;
23

34
namespace TorchSharp
45
{
@@ -12,4 +13,35 @@ public static void SetSeed(long seed)
1213
NN_Seed(seed);
1314
}
1415
}
16+
17+
public class AutoGradMode : IDisposable
18+
{
19+
[DllImport("LibTorchSharp")]
20+
extern static bool THS_gradmode_is_enabled();
21+
22+
[DllImport("LibTorchSharp")]
23+
extern static void THS_gradmode_set_enabled(bool enabled);
24+
25+
public AutoGradMode(bool enabled)
26+
{
27+
prev_mode = THS_gradmode_is_enabled();
28+
THS_gradmode_set_enabled(enabled);
29+
}
30+
31+
public void Dispose()
32+
{
33+
Dispose(true);
34+
GC.SuppressFinalize(this);
35+
}
36+
37+
public void Dispose(bool disposing)
38+
{
39+
if (disposing)
40+
{
41+
THS_gradmode_set_enabled(prev_mode);
42+
}
43+
}
44+
45+
bool prev_mode;
46+
}
1547
}

0 commit comments

Comments
 (0)