Skip to content

Commit 7759c63

Browse files
committed
Refactoring according to the libTorchSharp side.
1 parent 9454a60 commit 7759c63

22 files changed

+588
-434
lines changed

Test/TorchSharp.cs

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -133,31 +133,46 @@ public void CopyCpuToCuda()
133133
var cpu = FloatTensor.Ones(new long[] { 2, 2 });
134134
Assert.AreEqual(cpu.Device, "cpu");
135135

136-
var cuda = cpu.Cuda();
137-
Assert.AreEqual(cuda.Device, "cuda");
136+
if (Torch.IsCudaAvailable())
137+
{
138+
var cuda = cpu.Cuda();
139+
Assert.AreEqual(cuda.Device, "cuda");
138140

139-
// Copy back to CPU to inspect the elements
140-
cpu = cuda.Cpu();
141-
var data = cpu.Data;
142-
for (int i = 0; i < 4; i++)
141+
// Copy back to CPU to inspect the elements
142+
cpu = cuda.Cpu();
143+
var data = cpu.Data;
144+
for (int i = 0; i < 4; i++)
145+
{
146+
Assert.AreEqual(data[i], 1);
147+
}
148+
}
149+
else
143150
{
144-
Assert.AreEqual(data[i], 1);
151+
Assert.ThrowsException<InvalidOperationException>(cpu.Cuda);
145152
}
153+
146154
}
147155

148156
[TestMethod]
149157
public void CopyCudaToCpu()
150158
{
151-
var cuda = FloatTensor.Ones(new long[] { 2, 2 }, "cuda");
152-
Assert.AreEqual(cuda.Device, "cuda");
159+
if (Torch.IsCudaAvailable())
160+
{
161+
var cuda = FloatTensor.Ones(new long[] { 2, 2 }, "cuda");
162+
Assert.AreEqual(cuda.Device, "cuda");
153163

154-
var cpu = cuda.Cpu();
155-
Assert.AreEqual(cpu.Device, "cpu");
164+
var cpu = cuda.Cpu();
165+
Assert.AreEqual(cpu.Device, "cpu");
156166

157-
var data = cpu.Data;
158-
for (int i = 0; i < 4; i++)
167+
var data = cpu.Data;
168+
for (int i = 0; i < 4; i++)
169+
{
170+
Assert.AreEqual(data[i], 1);
171+
}
172+
}
173+
else
159174
{
160-
Assert.AreEqual(data[i], 1);
175+
Assert.ThrowsException<InvalidOperationException>(() => { FloatTensor.Ones(new long[] { 2, 2 }, "cuda"); });
161176
}
162177
}
163178

@@ -270,7 +285,7 @@ public void EvalLossSequence()
270285
var y = FloatTensor.RandomN(new long[] { 64, 10 }, device: "cpu:0");
271286

272287
var eval = seq.Forward(x);
273-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.None);
288+
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
274289

275290
var result = loss.Item;
276291
Assert.IsNotNull(result);
@@ -357,13 +372,15 @@ public void TestAutoGradMode()
357372
var x = FloatTensor.RandomN(new long[] { 2, 3 }, device: "cpu:0", requiresGrad: true);
358373
using (var mode = new AutoGradMode(false))
359374
{
375+
Assert.IsFalse(AutoGradMode.IsAutogradEnabled());
360376
var sum = x.Sum();
361377
sum.Backward();
362378
var grad = x.Grad();
363379
Assert.IsTrue(grad.Handle == IntPtr.Zero);
364380
}
365381
using (var mode = new AutoGradMode(true))
366382
{
383+
Assert.IsTrue(AutoGradMode.IsAutogradEnabled());
367384
var sum = x.Sum();
368385
sum.Backward();
369386
var grad = x.Grad();
@@ -435,7 +452,7 @@ public void TestTraining()
435452
for (int i = 0; i < 10; i++)
436453
{
437454
var eval = seq.Forward(x);
438-
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.None);
455+
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
439456
var lossVal = loss.Item;
440457

441458
Assert.IsTrue(lossVal < prevLoss);
@@ -445,13 +462,14 @@ public void TestTraining()
445462

446463
loss.Backward();
447464

448-
// using(var noGrad = NN.NoGrad())
449-
// The operators Mul and SubInPlace have no_grad=true by default
450-
foreach (var param in seq.Parameters())
465+
using (var noGrad = new AutoGradMode(false))
451466
{
452-
var grad = param.Grad();
453-
var update = grad.Mul(learning_rate);
454-
param.SubInPlace(update);
467+
foreach (var param in seq.Parameters())
468+
{
469+
var grad = param.Grad();
470+
var update = grad.Mul(learning_rate);
471+
param.SubInPlace(update);
472+
}
455473
}
456474
}
457475
}

TorchSharp/Autograd.cs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using System;
2+
using System.Runtime.InteropServices;
3+
4+
namespace TorchSharp
5+
{
6+
public sealed class AutoGradMode : IDisposable
7+
{
8+
private bool _isPrevGrad;
9+
10+
[DllImport("LibTorchSharp")]
11+
extern static bool THSAutograd_isGradEnabled();
12+
13+
[DllImport("LibTorchSharp")]
14+
extern static void THSAutograd_setGrad(bool enabled);
15+
16+
public AutoGradMode(bool enabled)
17+
{
18+
_isPrevGrad = THSAutograd_isGradEnabled();
19+
THSAutograd_setGrad(enabled);
20+
}
21+
22+
public void Dispose()
23+
{
24+
Dispose(true);
25+
GC.SuppressFinalize(this);
26+
}
27+
28+
public void Dispose(bool disposing)
29+
{
30+
if (disposing)
31+
{
32+
THSAutograd_setGrad(_isPrevGrad);
33+
}
34+
}
35+
36+
public static bool IsAutogradEnabled()
37+
{
38+
return THSAutograd_isGradEnabled();
39+
}
40+
}
41+
}

TorchSharp/Data/DataIterator.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,19 @@ namespace TorchSharp.Data
1313
internal static class ExternMethods
1414
{
1515
[DllImport("libTorchSharp")]
16-
extern internal static IntPtr Data_Current(IntPtr iterator, IntPtr data, IntPtr target);
16+
extern internal static IntPtr THSData_current(IntPtr iterator, IntPtr data, IntPtr target);
1717

1818
[DllImport("libTorchSharp")]
19-
extern internal static bool Data_MoveNext(IntPtr iterator);
19+
extern internal static bool THSData_moveNext(IntPtr iterator);
2020

2121
[DllImport("libTorchSharp")]
22-
extern internal static long Data_Size(IntPtr iterator);
22+
extern internal static long THSData_size(IntPtr iterator);
2323

2424
[DllImport("libTorchSharp")]
25-
extern internal static void Data_Reset(IntPtr iterator);
25+
extern internal static void THSData_reset(IntPtr iterator);
2626

2727
[DllImport("libTorchSharp")]
28-
extern internal static void Data_Dispose(IntPtr iterator);
28+
extern internal static void THSData_dispose(IntPtr iterator);
2929
}
3030

3131
/// <summary>
@@ -100,7 +100,7 @@ protected void Dispose(bool disposing)
100100
{
101101
if (disposing)
102102
{
103-
ExternMethods.Data_Dispose(handle.DangerousGetHandle());
103+
ExternMethods.THSData_dispose(handle.DangerousGetHandle());
104104
handle.Dispose();
105105
handle.SetHandleAsInvalid();
106106
}
@@ -112,7 +112,7 @@ protected void Dispose(bool disposing)
112112
/// <returns></returns>
113113
public long Size()
114114
{
115-
return ExternMethods.Data_Size(handle.DangerousGetHandle());
115+
return ExternMethods.THSData_size(handle.DangerousGetHandle());
116116
}
117117

118118
/// <summary>
@@ -156,7 +156,7 @@ public DataIteratorEnumerator(DataIterator<TData, TTarget> iterator)
156156
{
157157
get
158158
{
159-
ExternMethods.Data_Current(_iterator.handle.DangerousGetHandle(), _dRef, _tRef);
159+
ExternMethods.THSData_current(_iterator.handle.DangerousGetHandle(), _dRef, _tRef);
160160
return (_darray.Array[0].ToTorchTensor<TData>(), _tarray.Array[0].ToTorchTensor<TTarget>());
161161
}
162162
}
@@ -171,13 +171,13 @@ public bool MoveNext()
171171
return true;
172172
}
173173

174-
return ExternMethods.Data_MoveNext(_iterator.handle.DangerousGetHandle());
174+
return ExternMethods.THSData_moveNext(_iterator.handle.DangerousGetHandle());
175175
}
176176

177177
public void Reset()
178178
{
179179
_isFirst = true;
180-
ExternMethods.Data_Reset(_iterator.handle.DangerousGetHandle());
180+
ExternMethods.THSData_reset(_iterator.handle.DangerousGetHandle());
181181
}
182182

183183
public void Dispose()

TorchSharp/Data/Loader.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace TorchSharp.Data
66
public class Loader
77
{
88
[DllImport("libTorchSharp")]
9-
extern static IntPtr Data_LoaderMNIST(string filename, long batchSize, bool isTrain);
9+
extern static IntPtr THSData_loaderMNIST(string filename, long batchSize, bool isTrain);
1010

1111
/// <summary>
1212
/// Create an iterator scanning the MNIST dataset.
@@ -17,7 +17,7 @@ public class Loader
1717
/// <returns></returns>
1818
static public DataIterator<int, int> MNIST(string filename, long batchSize, bool isTrain = true)
1919
{
20-
return new DataIterator<int, int>(Data_LoaderMNIST(filename, batchSize, isTrain));
20+
return new DataIterator<int, int>(THSData_loaderMNIST(filename, batchSize, isTrain));
2121
}
2222
}
2323
}

TorchSharp/JIT/Type/TensorType.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,27 +18,27 @@ internal TensorType(Type type) : base()
1818
}
1919

2020
[DllImport("libTorchSharp")]
21-
extern static short JIT_TensorType_getScalar(HType handle);
21+
extern static short THSJIT_getScalarFromTensorType(HType handle);
2222

2323
public Tensor.ATenScalarMapping GetScalarType()
2424
{
25-
return (Tensor.ATenScalarMapping)JIT_TensorType_getScalar(handle);
25+
return (Tensor.ATenScalarMapping)THSJIT_getScalarFromTensorType(handle);
2626
}
2727

2828
[DllImport("libTorchSharp")]
29-
extern static int JIT_TensorType_getDimensions(HType handle);
29+
extern static int THSJIT_getTensorTypeDimensions(HType handle);
3030

3131
public int GetDimensions()
3232
{
33-
return JIT_TensorType_getDimensions(handle);
33+
return THSJIT_getTensorTypeDimensions(handle);
3434
}
3535

3636
[DllImport("libTorchSharp")]
37-
extern static string JIT_TensorType_getDevice(HType handle);
37+
extern static string THSJIT_getTensorDevice(HType handle);
3838

3939
public string GetDevice()
4040
{
41-
return JIT_TensorType_getDevice(handle);
41+
return THSJIT_getTensorDevice(handle);
4242
}
4343
}
4444
}

TorchSharp/JIT/Type/Type.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ internal HType() : base(IntPtr.Zero, true)
2323
}
2424

2525
[DllImport("libTorchSharp")]
26-
extern static void JIT_Type_Dispose(HType handle);
26+
extern static void THSJIT_typeDispose(HType handle);
2727

2828
protected override bool ReleaseHandle()
2929
{
30-
JIT_Type_Dispose(this);
30+
THSJIT_typeDispose(this);
3131
return true;
3232
}
3333

@@ -78,11 +78,11 @@ protected void Dispose(bool disposing)
7878
}
7979

8080
[DllImport("libTorchSharp")]
81-
extern static short JIT_TypeKind(HType handle);
81+
extern static short THSJIT_typeKind(HType handle);
8282

8383
internal TypeKind Kind
8484
{
85-
get { return (TypeKind)JIT_TypeKind(handle); }
85+
get { return (TypeKind)THSJIT_typeKind(handle); }
8686
}
8787

8888
internal TensorType AsTensorType()

TorchSharp/NN/Conv2D.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ internal Conv2D(IntPtr handle) : base(handle)
1111
}
1212

1313
[DllImport("libTorchSharp")]
14-
extern static IntPtr NN_conv2DModule_Forward(Module.HType module, IntPtr tensor);
14+
extern static IntPtr THSNN_conv2DModuleApply(Module.HType module, IntPtr tensor);
1515

1616
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
1717
{
18-
return new FloatTensor(NN_conv2DModule_Forward(handle, tensor.Handle));
18+
return new FloatTensor(THSNN_conv2DModuleApply(handle, tensor.Handle));
1919
}
2020
}
2121
}

TorchSharp/NN/Dropout.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ internal Dropout(double probability, bool isTraining) : base()
1818
}
1919

2020
[DllImport("libTorchSharp")]
21-
extern static IntPtr NN_DropoutModule_Forward(IntPtr tensor, double probability, bool isTraining);
21+
extern static IntPtr THSNN_dropoutModuleApply(IntPtr tensor, double probability, bool isTraining);
2222

2323
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
2424
{
25-
return new FloatTensor(NN_DropoutModule_Forward(tensor.Handle, _probability, _isTraining));
25+
return new FloatTensor(THSNN_dropoutModuleApply(tensor.Handle, _probability, _isTraining));
2626
}
2727
}
2828
}

TorchSharp/NN/FeatureDropout.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ internal FeatureDropout() : base()
1414
}
1515

1616
[DllImport("libTorchSharp")]
17-
extern static IntPtr NN_FeatureDropout_Forward(IntPtr tensor);
17+
extern static IntPtr THSNN_featureDropoutApply(IntPtr tensor);
1818

1919
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
2020
{
21-
return new FloatTensor(NN_FeatureDropout_Forward(tensor.Handle));
21+
return new FloatTensor(THSNN_featureDropoutApply(tensor.Handle));
2222
}
2323
}
2424
}

TorchSharp/NN/Linear.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ internal Linear(IntPtr handle) : base(handle)
1111
}
1212

1313
[DllImport("libTorchSharp")]
14-
extern static IntPtr NN_linearModule_Forward(Module.HType module, IntPtr tensor);
14+
extern static IntPtr THSNN_linearModuleApply(Module.HType module, IntPtr tensor);
1515

1616
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
1717
{
18-
return new FloatTensor(NN_linearModule_Forward(handle, tensor.Handle));
18+
return new FloatTensor(THSNN_linearModuleApply(handle, tensor.Handle));
1919
}
2020
}
2121
}

0 commit comments

Comments
 (0)