Skip to content

Commit 3f57f43

Browse files
committed
Removed type information from tensors.
1 parent 14e90a4 commit 3f57f43

25 files changed

+1798
-4909
lines changed

Test/TorchSharp/TorchSharp.cs

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System;
33
using System.Linq;
44
using TorchSharp.JIT;
5+
using TorchSharp.NN;
56
using TorchSharp.Tensor;
67

78
namespace TorchSharp.Test
@@ -19,7 +20,7 @@ public void CreateFloatTensorOnes()
1920
[TestMethod]
2021
public void CreateFloatTensorCheckMemory()
2122
{
22-
ITorchTensor<float> ones = null;
23+
ITorchTensor ones = null;
2324

2425
for (int i = 0; i < 10; i++)
2526
{
@@ -35,7 +36,7 @@ public void CreateFloatTensorCheckMemory()
3536
public void CreateFloatTensorOnesCheckData()
3637
{
3738
var ones = FloatTensor.Ones(new long[] { 2, 2 });
38-
var data = ones.Data;
39+
var data = ones.Data<float>();
3940

4041
for (int i = 0; i < 4; i++)
4142
{
@@ -47,7 +48,7 @@ public void CreateFloatTensorOnesCheckData()
4748
public void CreateFloatTensorZerosCheckData()
4849
{
4950
var zeros = FloatTensor.Zeros(new long[] { 2, 2 });
50-
var data = zeros.Data;
51+
var data = zeros.Data<float>();
5152

5253
for (int i = 0; i < 4; i++)
5354
{
@@ -59,7 +60,7 @@ public void CreateFloatTensorZerosCheckData()
5960
public void CreateIntTensorOnesCheckData()
6061
{
6162
var ones = IntTensor.Ones(new long[] { 2, 2 });
62-
var data = ones.Data;
63+
var data = ones.Data<int>();
6364

6465
for (int i = 0; i < 4; i++)
6566
{
@@ -84,7 +85,7 @@ public void CreateFloatTensorFromData()
8485

8586
using (var tensor = FloatTensor.From(data, new long[] { 100, 10 }))
8687
{
87-
Assert.AreEqual(tensor.Data[100], 1);
88+
Assert.AreEqual(tensor.Data<float>()[100], 1);
8889
}
8990
}
9091

@@ -96,7 +97,7 @@ public void CreateFloatTensorFromDataCheckDispose()
9697

9798
using (var tensor = FloatTensor.From(data, new long[] { 100, 10 }))
9899
{
99-
Assert.AreEqual(tensor.Data[100], 1);
100+
Assert.AreEqual(tensor.Data<float>()[100], 1);
100101
}
101102

102103
Assert.AreEqual(data[100], 1);
@@ -114,7 +115,7 @@ private static void CreateFloatTensorFromData2Generic<T>()
114115

115116
using (var tensor = data.ToTorchTensor(new long[] { 10, 100 }))
116117
{
117-
Assert.AreEqual(tensor.Data[100], default(T));
118+
Assert.AreEqual(tensor.Data<T>()[100], default(T));
118119
}
119120
}
120121

@@ -125,7 +126,7 @@ public void CreateFloatTensorFromScalar()
125126

126127
using (var tensor = FloatTensor.From(scalar))
127128
{
128-
Assert.AreEqual(tensor.DataItem, 333);
129+
Assert.AreEqual(tensor.DataItem<float>(), 333);
129130
}
130131
}
131132

@@ -136,7 +137,7 @@ public void CreateFloatTensorFromScalar2()
136137

137138
using (var tensor = scalar.ToTorchTensor())
138139
{
139-
Assert.AreEqual(tensor.DataItem, 333);
140+
Assert.AreEqual(tensor.DataItem<float>(), 333);
140141
}
141142
}
142143

@@ -177,7 +178,7 @@ public void CopyCpuToCuda()
177178

178179
// Copy back to CPU to inspect the elements
179180
cpu = cuda.Cpu();
180-
var data = cpu.Data;
181+
var data = cpu.Data<float>();
181182
for (int i = 0; i < 4; i++)
182183
{
183184
Assert.AreEqual(data[i], 1);
@@ -201,7 +202,7 @@ public void CopyCudaToCpu()
201202
var cpu = cuda.Cpu();
202203
Assert.AreEqual(cpu.Device, "cpu");
203204

204-
var data = cpu.Data;
205+
var data = cpu.Data<float>();
205206
for (int i = 0; i < 4; i++)
206207
{
207208
Assert.AreEqual(data[i], 1);
@@ -362,7 +363,7 @@ public void TestLinearWithBias()
362363

363364
for (int i = 0; i < 100; i++)
364365
{
365-
Assert.AreEqual(forward.Data[i], matmul.Data[i]);
366+
Assert.AreEqual(forward.Data<float>()[i], matmul.Data<float>()[i]);
366367
}
367368
}
368369

@@ -381,7 +382,7 @@ public void TestLinearNoBias()
381382

382383
for (int i = 0; i < 100; i++)
383384
{
384-
Assert.AreEqual(forward.Data[i], matmul.Data[i]);
385+
Assert.AreEqual(forward.Data<float>()[i], matmul.Data<float>()[i]);
385386
}
386387
}
387388

@@ -428,17 +429,17 @@ public void EvalLossSequence()
428429
var eval = seq.Forward(x);
429430
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
430431

431-
var result = loss.DataItem;
432+
var result = loss.DataItem<float>();
432433
Assert.IsNotNull(result);
433434
}
434435

435436
[TestMethod]
436437
public void TestPoissonNLLLoss()
437438
{
438-
using (FloatTensor input = FloatTensor.From(new float[] { 0.5f, 1.5f, 2.5f }))
439-
using (FloatTensor target = FloatTensor.From(new float[] { 1f, 2f, 3f }))
439+
using (TorchTensor input = FloatTensor.From(new float[] { 0.5f, 1.5f, 2.5f }))
440+
using (TorchTensor target = FloatTensor.From(new float[] { 1f, 2f, 3f }))
440441
{
441-
var componentWiseLoss = ((FloatTensor)input.Exp()) - target * input;
442+
var componentWiseLoss = ((TorchTensor)input.Exp()) - target * input;
442443
Assert.IsTrue(componentWiseLoss.Equal(NN.LossFunction.PoissonNLL(input, target, reduction: NN.Reduction.None)));
443444
Assert.IsTrue(componentWiseLoss.Sum().Equal(NN.LossFunction.PoissonNLL(input, target, reduction: NN.Reduction.Sum)));
444445
Assert.IsTrue(componentWiseLoss.Mean().Equal(NN.LossFunction.PoissonNLL(input, target, reduction: NN.Reduction.Mean)));
@@ -539,7 +540,7 @@ public void TestAutoGradMode()
539540
sum.Backward();
540541
var grad = x.Grad();
541542
Assert.IsFalse(grad.Handle == IntPtr.Zero);
542-
var data = grad.Data;
543+
var data = grad.Data<float>();
543544
for (int i = 0; i < 2 * 3; i++)
544545
{
545546
Assert.AreEqual(data[i], 1.0);
@@ -555,7 +556,7 @@ public void TestSubInPlace()
555556

556557
x.SubInPlace(y);
557558

558-
var xdata = x.Data;
559+
var xdata = x.Data<int>();
559560

560561
for (int i = 0; i < 100; i++)
561562
{
@@ -573,8 +574,8 @@ public void TestMul()
573574

574575
var y = x.Mul(0.5f);
575576

576-
var ydata = y.Data;
577-
var xdata = x.Data;
577+
var ydata = y.Data<float>();
578+
var xdata = x.Data<float>();
578579

579580
for (int i = 0; i < 100; i++)
580581
{
@@ -615,12 +616,12 @@ public void TestCustomModuleWithInPlaceModification()
615616

616617
private class TestModule : NN.Module
617618
{
618-
public TestModule(string name, ITorchTensor<float> tensor, bool withGrad)
619-
: base((name, tensor, withGrad))
619+
public TestModule(string name, ITorchTensor tensor, bool withGrad)
620+
: base(new NN.Parameter(name, tensor, withGrad))
620621
{
621622
}
622623

623-
public override ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
624+
public override ITorchTensor Forward(ITorchTensor input)
624625
{
625626
throw new NotImplementedException();
626627
}
@@ -648,7 +649,7 @@ public void TestTraining()
648649
{
649650
var eval = seq.Forward(x);
650651
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
651-
var lossVal = loss.DataItem;
652+
var lossVal = loss.DataItem<float>();
652653

653654
Assert.IsTrue(lossVal < prevLoss);
654655
prevLoss = lossVal;
@@ -705,7 +706,7 @@ public void TestTrainingAdam()
705706
{
706707
var eval = seq.Forward(x);
707708
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
708-
var lossVal = loss.DataItem;
709+
var lossVal = loss.DataItem<float>();
709710

710711
Assert.IsTrue(lossVal < prevLoss);
711712
prevLoss = lossVal;

TorchSharp/Data/DataIterator.cs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,9 @@ internal static class ExternMethods
3131
/// <summary>
3232
/// Class implementing enumerable over PyTorch's iterator.
3333
/// </summary>
34-
/// <typeparam name="TData"></typeparam>
35-
/// <typeparam name="TTarget"></typeparam>
36-
public class DataIterator<TData, TTarget> :
34+
public class DataIterator :
3735
IDisposable,
38-
IEnumerable<(ITorchTensor<TData> data, ITorchTensor<TTarget> target)>
36+
IEnumerable<(ITorchTensor data, ITorchTensor target)>
3937
{
4038
/// <summary>
4139
/// Class wrapping PyTorch's iterator object reference.
@@ -119,7 +117,7 @@ public long Size()
119117
/// Get the enumerator for this iterator.
120118
/// </summary>
121119
/// <returns></returns>
122-
public IEnumerator<(ITorchTensor<TData> data, ITorchTensor<TTarget> target)> GetEnumerator()
120+
public IEnumerator<(ITorchTensor data, ITorchTensor target)> GetEnumerator()
123121
{
124122
var iter = new DataIteratorEnumerator(this);
125123
iter.Reset();
@@ -132,9 +130,9 @@ IEnumerator IEnumerable.GetEnumerator()
132130
return GetEnumerator();
133131
}
134132

135-
private class DataIteratorEnumerator : IEnumerator<(ITorchTensor<TData> data, ITorchTensor<TTarget> target)>
133+
private class DataIteratorEnumerator : IEnumerator<(ITorchTensor data, ITorchTensor target)>
136134
{
137-
private DataIterator<TData, TTarget> _iterator;
135+
private DataIterator _iterator;
138136

139137
private readonly PinnedArray<IntPtr> _darray = new PinnedArray<IntPtr>();
140138
private readonly PinnedArray<IntPtr> _tarray = new PinnedArray<IntPtr>();
@@ -144,20 +142,20 @@ private class DataIteratorEnumerator : IEnumerator<(ITorchTensor<TData> data, IT
144142

145143
private bool _isFirst = true;
146144

147-
public DataIteratorEnumerator(DataIterator<TData, TTarget> iterator)
145+
public DataIteratorEnumerator(DataIterator iterator)
148146
{
149147
_iterator = iterator;
150148

151149
_dRef = _darray.CreateArray(new IntPtr[1]);
152150
_tRef = _tarray.CreateArray(new IntPtr[1]);
153151
}
154152

155-
public (ITorchTensor<TData> data, ITorchTensor<TTarget> target) Current
153+
public (ITorchTensor data, ITorchTensor target) Current
156154
{
157155
get
158156
{
159157
ExternMethods.THSData_current(_iterator.handle.DangerousGetHandle(), _dRef, _tRef);
160-
return (_darray.Array[0].ToTorchTensor<TData>(), _tarray.Array[0].ToTorchTensor<TTarget>());
158+
return (new TorchTensor(_darray.Array[0]), new TorchTensor(_tarray.Array[0]));
161159
}
162160
}
163161

TorchSharp/Data/Loader.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ public class Loader
1515
/// <param name="batchSize">The required batch size</param>
1616
/// <param name="isTrain">Whether the iterator is for training or testing</param>
1717
/// <returns></returns>
18-
static public DataIterator<int, int> MNIST(string filename, long batchSize, bool isTrain = true)
18+
static public DataIterator MNIST(string filename, long batchSize, bool isTrain = true)
1919
{
20-
return new DataIterator<int, int>(THSData_loaderMNIST(filename, batchSize, isTrain));
20+
return new DataIterator(THSData_loaderMNIST(filename, batchSize, isTrain));
2121
}
2222
}
2323
}

TorchSharp/JIT/Module.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,12 @@ private Type GetType(Type type)
158158
[DllImport("libTorchSharp")]
159159
extern static IntPtr THSJIT_forward(Module.HType module, IntPtr tensors, int length);
160160

161-
public ITorchTensor<float> Forward<T>(params ITorchTensor<T>[] tensors)
161+
public ITorchTensor Forward(params ITorchTensor[] tensors)
162162
{
163163
var parray = new PinnedArray<IntPtr>();
164164
IntPtr tensorRefs = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
165165

166-
return new FloatTensor(THSJIT_forward(handle, tensorRefs, parray.Array.Length));
166+
return new TorchTensor(THSJIT_forward(handle, tensorRefs, parray.Array.Length));
167167
}
168168
}
169169
}

TorchSharp/NN/Conv2D.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ internal Conv2D(IntPtr handle) : base(handle)
1313
[DllImport("libTorchSharp")]
1414
extern static IntPtr THSNN_conv2DModuleApply(Module.HType module, IntPtr tensor);
1515

16-
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
16+
public override ITorchTensor Forward(ITorchTensor tensor)
1717
{
18-
return new FloatTensor(THSNN_conv2DModuleApply(handle, tensor.Handle));
18+
return new TorchTensor(THSNN_conv2DModuleApply(handle, tensor.Handle));
1919
}
2020
}
2121
}

TorchSharp/NN/Dropout.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ internal Dropout(double probability, bool isTraining) : base()
2121
[DllImport("libTorchSharp")]
2222
extern static IntPtr THSNN_dropoutModuleApply(IntPtr tensor, double probability, bool isTraining);
2323

24-
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
24+
public override ITorchTensor Forward(ITorchTensor tensor)
2525
{
26-
return new FloatTensor(THSNN_dropoutModuleApply(tensor.Handle, _probability, _isTraining));
26+
return new TorchTensor(THSNN_dropoutModuleApply(tensor.Handle, _probability, _isTraining));
2727
}
2828
}
2929
}

TorchSharp/NN/FeatureDropout.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ internal FeatureDropout() : base()
1616
[DllImport("libTorchSharp")]
1717
extern static IntPtr THSNN_featureDropoutApply(IntPtr tensor);
1818

19-
public override ITorchTensor<float> Forward<T>(ITorchTensor<T> tensor)
19+
public override ITorchTensor Forward(ITorchTensor tensor)
2020
{
21-
return new FloatTensor(THSNN_featureDropoutApply(tensor.Handle));
21+
return new TorchTensor(THSNN_featureDropoutApply(tensor.Handle));
2222
}
2323
}
2424
}

TorchSharp/NN/FunctionalModule.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ public override void ZeroGrad()
2121
{
2222
}
2323

24-
public override IEnumerable<(string name, ITorchTensor<float> parameter)> NamedParameters()
24+
public override IEnumerable<(string name, ITorchTensor parameter)> NamedParameters()
2525
{
26-
return new List<(string, ITorchTensor<float>)>();
26+
return new List<(string, ITorchTensor)>();
2727
}
2828

29-
public override IEnumerable<ITorchTensor<float>> Parameters()
29+
public override IEnumerable<ITorchTensor> Parameters()
3030
{
31-
return new List<ITorchTensor<float>>();
31+
return new List<ITorchTensor>();
3232
}
3333

3434
public override IEnumerable<string> GetModules()

TorchSharp/NN/Init.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@ public static class Init
99
[DllImport("libTorchSharp")]
1010
extern static void THSNN_initUniform(IntPtr src, double low, double high);
1111

12-
public static void Uniform<T>(ITorchTensor<T> tensor, double low = 0, double high = 1)
12+
public static void Uniform(ITorchTensor tensor, double low = 0, double high = 1)
1313
{
1414
THSNN_initUniform(tensor.Handle, low, high);
1515
}
1616

1717
[DllImport("libTorchSharp")]
1818
extern static void THSNN_initKaimingUniform(IntPtr src, double a);
1919

20-
public static void KaimingUniform<T>(ITorchTensor<T> tensor, double a = 0)
20+
public static void KaimingUniform(ITorchTensor tensor, double a = 0)
2121
{
2222
THSNN_initKaimingUniform(tensor.Handle, a);
2323
}
2424

25-
public static (long fanIn, long fanOut) CalculateFanInAndFanOut<T>(ITorchTensor<T> tensor)
25+
public static (long fanIn, long fanOut) CalculateFanInAndFanOut<T>(ITorchTensor tensor)
2626
{
2727
var dimensions = tensor.Dimensions;
2828

@@ -35,7 +35,7 @@ public static (long fanIn, long fanOut) CalculateFanInAndFanOut<T>(ITorchTensor<
3535
// Linear
3636
if (dimensions == 2)
3737
{
38-
return (shape[1], shape[2]);
38+
return (shape[1], shape[0]);
3939
}
4040
else
4141
{

0 commit comments

Comments
 (0)