Skip to content

Commit 7ed3a7b

Browse files
committed
Added T plus some additional test over the linear module.
1 parent 94ed5a1 commit 7ed3a7b

File tree

5 files changed

+143
-30
lines changed

5 files changed

+143
-30
lines changed

Test/TorchSharp/TorchSharp.cs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,57 @@ public void TestSetGetBiasInLinear()
302302
Assert.AreEqual(lin.Bias.NumberOfElements, bias.NumberOfElements);
303303
}
304304

305+
[TestMethod]
306+
public void TestWeightAndBiasShapeInLinear()
307+
{
308+
var lin = NN.Module.Linear(1000, 100, true);
309+
310+
Assert.AreEqual(lin.Weight.Shape.Length, 2);
311+
Assert.AreEqual(lin.Weight.Shape[0], 100);
312+
Assert.AreEqual(lin.Weight.Shape[1], 1000);
313+
Assert.AreEqual(lin.Bias.Shape.Length, 1);
314+
Assert.AreEqual(lin.Bias.Shape[0], 100);
315+
}
316+
317+
[TestMethod]
318+
public void TestLinearWithBias()
319+
{
320+
var lin = NN.Module.Linear(1000, 100, true);
321+
var bias = lin.Bias;
322+
var weight = lin.Weight.T();
323+
var input = FloatTensor.RandomN(new long[] { 1, 1000 });
324+
var forward = lin.Forward(input);
325+
var matmul = input.MatMul(weight).Add(bias);
326+
327+
Assert.AreEqual(forward.Shape.Length, matmul.Shape.Length);
328+
Assert.AreEqual(forward.Shape[0], matmul.Shape[0]);
329+
Assert.AreEqual(forward.Shape[1], matmul.Shape[1]);
330+
331+
for (int i = 0; i < 100; i++)
332+
{
333+
Assert.AreEqual(forward.Data[i], matmul.Data[i]);
334+
}
335+
}
336+
337+
[TestMethod]
338+
public void TestLinearNoBias()
339+
{
340+
var lin = NN.Module.Linear(1000, 100, false);
341+
var weight = lin.Weight.Transpose(0, 1);
342+
var input = FloatTensor.RandomN(new long[] { 1, 1000 });
343+
var forward = lin.Forward(input);
344+
var matmul = input.MatMul(weight);
345+
346+
Assert.AreEqual(forward.Shape.Length, matmul.Shape.Length);
347+
Assert.AreEqual(forward.Shape[0], matmul.Shape[0]);
348+
Assert.AreEqual(forward.Shape[1], matmul.Shape[1]);
349+
350+
for (int i = 0; i < 100; i++)
351+
{
352+
Assert.AreEqual(forward.Data[i], matmul.Data[i]);
353+
}
354+
}
355+
305356
[TestMethod]
306357
public void CreateRelu()
307358
{

TorchSharp/NN/Module.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public partial class Module
1111
/// <summary>
1212
/// Class wrapping PyTorch's module object reference.
1313
/// </summary>
14-
protected sealed class HType : SafeHandle
14+
internal sealed class HType : SafeHandle
1515
{
1616
public HType(IntPtr preexistingHandle, bool ownsHandle) : base(IntPtr.Zero, ownsHandle)
1717
{
@@ -43,7 +43,7 @@ protected override void Dispose(bool disposing)
4343
}
4444
}
4545

46-
protected HType handle;
46+
internal HType handle;
4747

4848
protected bool _isTraining = true;
4949

TorchSharp/Tensor/ITorchTensor.cs

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
namespace TorchSharp.Tensor
44
{
5-
public interface ITorchTensor<T> : IDisposable
5+
public interface ITorchTensor<U> : IDisposable
66
{
77
IntPtr Handle { get; }
88

@@ -14,15 +14,15 @@ public interface ITorchTensor<T> : IDisposable
1414

1515
string Device { get; }
1616

17-
Span<T> Data { get; }
17+
Span<U> Data { get; }
1818

19-
T DataItem { get; }
19+
U DataItem { get; }
2020

21-
ITorchTensor<T> this[long i1] { get; }
21+
ITorchTensor<U> this[long i1] { get; }
2222

23-
ITorchTensor<T> this[long i1, long i2] { get; }
23+
ITorchTensor<U> this[long i1, long i2] { get; }
2424

25-
ITorchTensor<T> this[long i1, long i2, long i3] { get; }
25+
ITorchTensor<U> this[long i1, long i2, long i3] { get; }
2626

2727
bool IsSparse { get; }
2828

@@ -32,56 +32,62 @@ public interface ITorchTensor<T> : IDisposable
3232

3333
long GetTensorStride(int dim);
3434

35-
ITorchTensor<T> Cpu();
35+
ITorchTensor<U> Cpu();
3636

37-
ITorchTensor<T> Cuda();
37+
ITorchTensor<U> Cuda();
3838

3939
void Backward();
4040

4141
ITorchTensor<float> Grad();
4242

43-
ITorchTensor<T> Reshape(params long[] shape);
43+
ITorchTensor<U> Reshape(params long[] shape);
4444

45-
ITorchTensor<T> View(params long[] shape);
45+
ITorchTensor<U> T();
46+
47+
ITorchTensor<U> Transpose(long dimension1, long dimension2);
48+
49+
void TransposeInPlace(long dimension1, long dimension2);
50+
51+
ITorchTensor<U> View(params long[] shape);
4652

4753
ITorchTensor<U> Eq<U>(ITorchTensor<U> target);
4854

4955
bool Equal<U>(ITorchTensor<U> target);
5056

51-
ITorchTensor<T> Add(ITorchTensor<T> target, int scalar);
57+
ITorchTensor<U> Add(ITorchTensor<U> target, int scalar = 1);
5258

53-
void AddInPlace(ITorchTensor<T> target, int scalar);
59+
void AddInPlace(ITorchTensor<U> target, int scalar);
5460

55-
ITorchTensor<T> Addbmm(ITorchTensor<T> batch1, ITorchTensor<T> batch2, float beta, float alpha);
61+
ITorchTensor<U> Addbmm(ITorchTensor<U> batch1, ITorchTensor<U> batch2, float beta, float alpha);
5662

57-
ITorchTensor<T> Argmax(long dimension, bool keepDimension = false);
63+
ITorchTensor<U> Argmax(long dimension, bool keepDimension = false);
5864

59-
ITorchTensor<T> Baddbmm(ITorchTensor<T> batch2, ITorchTensor<T> mat, float beta, float alpha);
65+
ITorchTensor<U> Baddbmm(ITorchTensor<U> batch2, ITorchTensor<U> mat, float beta, float alpha);
6066

61-
ITorchTensor<T> Bmm(ITorchTensor<T> batch2);
67+
ITorchTensor<U> Bmm(ITorchTensor<U> batch2);
6268

63-
ITorchTensor<T> Exp();
69+
ITorchTensor<U> Exp();
6470

65-
ITorchTensor<T> MatMul(ITorchTensor<T> target);
71+
ITorchTensor<U> MatMul(ITorchTensor<U> target);
6672

67-
ITorchTensor<T> Mean();
73+
ITorchTensor<U> Mean();
6874

69-
ITorchTensor<T> Mm(ITorchTensor<T> target);
75+
ITorchTensor<U> Mm(ITorchTensor<U> target);
7076

71-
ITorchTensor<T> Mul(ITorchTensor<T> target);
77+
ITorchTensor<U> Mul(ITorchTensor<U> target);
7278

73-
ITorchTensor<T> Mul(T scalar);
79+
ITorchTensor<U> Mul(U scalar);
7480

75-
void MulInPlace(ITorchTensor<T> target);
81+
void MulInPlace(ITorchTensor<U> target);
7682

77-
ITorchTensor<T> Pow(float scalar);
83+
ITorchTensor<U> Pow(float scalar);
7884

79-
ITorchTensor<T> Sigmoid();
85+
ITorchTensor<U> Sigmoid();
8086

81-
ITorchTensor<T> Sub(ITorchTensor<T> target);
87+
ITorchTensor<U> Sub(ITorchTensor<U> target);
8288

83-
void SubInPlace(ITorchTensor<T> target);
89+
void SubInPlace(ITorchTensor<U> target);
8490

85-
ITorchTensor<T> Sum();
91+
ITorchTensor<U> Sum();
8692
}
8793
}

TorchSharp/Tensor/TorchTensor.generated.cs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,14 @@ public ITorchTensor<byte> Reshape(params long[] shape)
398398
}
399399
}
400400

401+
[DllImport("libTorchSharp")]
402+
extern static IntPtr THSTensor_t(IntPtr src);
403+
404+
public ITorchTensor<byte> T()
405+
{
406+
return new ByteTensor (THSTensor_t (handle));
407+
}
408+
401409
[DllImport("libTorchSharp")]
402410
extern static IntPtr THSTensor_transpose(IntPtr src, long dim1, long dim2);
403411

@@ -1034,6 +1042,14 @@ public ITorchTensor<short> Reshape(params long[] shape)
10341042
}
10351043
}
10361044

1045+
[DllImport("libTorchSharp")]
1046+
extern static IntPtr THSTensor_t(IntPtr src);
1047+
1048+
public ITorchTensor<short> T()
1049+
{
1050+
return new ShortTensor (THSTensor_t (handle));
1051+
}
1052+
10371053
[DllImport("libTorchSharp")]
10381054
extern static IntPtr THSTensor_transpose(IntPtr src, long dim1, long dim2);
10391055

@@ -1670,6 +1686,14 @@ public ITorchTensor<int> Reshape(params long[] shape)
16701686
}
16711687
}
16721688

1689+
[DllImport("libTorchSharp")]
1690+
extern static IntPtr THSTensor_t(IntPtr src);
1691+
1692+
public ITorchTensor<int> T()
1693+
{
1694+
return new IntTensor (THSTensor_t (handle));
1695+
}
1696+
16731697
[DllImport("libTorchSharp")]
16741698
extern static IntPtr THSTensor_transpose(IntPtr src, long dim1, long dim2);
16751699

@@ -2306,6 +2330,14 @@ public ITorchTensor<long> Reshape(params long[] shape)
23062330
}
23072331
}
23082332

2333+
[DllImport("libTorchSharp")]
2334+
extern static IntPtr THSTensor_t(IntPtr src);
2335+
2336+
public ITorchTensor<long> T()
2337+
{
2338+
return new LongTensor (THSTensor_t (handle));
2339+
}
2340+
23092341
[DllImport("libTorchSharp")]
23102342
extern static IntPtr THSTensor_transpose(IntPtr src, long dim1, long dim2);
23112343

@@ -2942,6 +2974,14 @@ public ITorchTensor<double> Reshape(params long[] shape)
29422974
}
29432975
}
29442976

2977+
[DllImport("libTorchSharp")]
2978+
extern static IntPtr THSTensor_t(IntPtr src);
2979+
2980+
public ITorchTensor<double> T()
2981+
{
2982+
return new DoubleTensor (THSTensor_t (handle));
2983+
}
2984+
29452985
[DllImport("libTorchSharp")]
29462986
extern static IntPtr THSTensor_transpose(IntPtr src, long dim1, long dim2);
29472987

@@ -3578,6 +3618,14 @@ public ITorchTensor<float> Reshape(params long[] shape)
35783618
}
35793619
}
35803620

3621+
[DllImport("libTorchSharp")]
3622+
extern static IntPtr THSTensor_t(IntPtr src);
3623+
3624+
public ITorchTensor<float> T()
3625+
{
3626+
return new FloatTensor (THSTensor_t (handle));
3627+
}
3628+
35813629
[DllImport("libTorchSharp")]
35823630
extern static IntPtr THSTensor_transpose(IntPtr src, long dim1, long dim2);
35833631

TorchSharp/Tensor/TorchTensor.tt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,14 @@ if (!type.IsLong) {
418418
}
419419
}
420420

421+
[DllImport("libTorchSharp")]
422+
extern static IntPtr THSTensor_t(IntPtr src);
423+
424+
public ITorchTensor<<#=type.Storage#>> T()
425+
{
426+
return new <#=type.Name#>Tensor (THSTensor_t (handle));
427+
}
428+
421429
[DllImport("libTorchSharp")]
422430
extern static IntPtr THSTensor_transpose(IntPtr src, long dim1, long dim2);
423431

0 commit comments

Comments
 (0)