Skip to content

Commit d205a83

Browse files
committed
Fixed few things plus added method to generate tensors from arrays.
1 parent e34a760 commit d205a83

File tree

5 files changed

+368
-32
lines changed

5 files changed

+368
-32
lines changed

Test/AtenSharp/BasicTensorAPI.cs

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,6 @@ public void CreateFloatTensor()
2424
Assert.AreEqual(20, x2Shape[1]);
2525
}
2626

27-
//[TestMethod]
28-
//public void CreateIntTensorOne()
29-
//{
30-
// var x1 = FloatTensor.Ones(new long[] { 1, 3, 224, 224 });
31-
32-
// //Assert.AreEqual(4, x1.Shape.Length);
33-
34-
// var module = Module.LoadModule(@"E:\Source\Repos\libtorch\model.pt");
35-
36-
// var modules = module.GetModules();
37-
// var result = module.Score(x1);
38-
//}
39-
4027
[TestMethod]
4128
public void GetFloatTensorData()
4229
{

Test/TorchSharp.cs

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,6 @@ public void CreateFloatTensorOnes()
1515
Assert.IsNotNull(ones);
1616
}
1717

18-
[TestMethod]
19-
public void CreateFloatTensorCheckDistructor()
20-
{
21-
ITorchTensor<float> ones = null;
22-
23-
using (var tmp = FloatTensor.Ones(new long[] { 2, 2 }))
24-
{
25-
ones = tmp;
26-
Assert.IsNotNull(ones);
27-
}
28-
Assert.ThrowsException<ObjectDisposedException>(ones.Grad);
29-
}
30-
3118
[TestMethod]
3219
public void CreateFloatTensorCheckMemory()
3320
{
@@ -76,6 +63,18 @@ public void CreateFloatTensorCheckDevice()
7663
Assert.AreEqual(ones.Device, "cpu");
7764
}
7865

66+
[TestMethod]
67+
public void CreateFloatTensorFromData()
68+
{
69+
var data = new float[1000];
70+
data[100] = 1;
71+
72+
using (var tensor = FloatTensor.From(data, new long[] { 100, 10 }, new long[] { 1, 100 }))
73+
{
74+
Assert.AreEqual(tensor.Data[100], 1);
75+
}
76+
}
77+
7978
[TestMethod]
8079
public void ScoreModel()
8180
{
@@ -90,7 +89,7 @@ public void ScoreModel()
9089
}
9190

9291
[TestMethod]
93-
public void ScoreModelCheckInput()
92+
public void LoadModelCheckInput()
9493
{
9594
var module = JIT.Module.Load(@"E:\Source\Repos\libtorch\model.pt");
9695
Assert.IsNotNull(module);
@@ -105,6 +104,22 @@ public void ScoreModelCheckInput()
105104
}
106105
}
107106

107+
[TestMethod]
108+
public void LoadModelCheckOutput()
109+
{
110+
var module = JIT.Module.Load(@"E:\Source\Repos\libtorch\model.pt");
111+
Assert.IsNotNull(module);
112+
113+
var num = module.GetNumberOfOutputs();
114+
115+
for (int i = 0; i < num; i++)
116+
{
117+
var type = module.GetOutputType(i);
118+
119+
Assert.IsNotNull(type as DynamicType);
120+
}
121+
}
122+
108123
[TestMethod]
109124
public void ScoreModelCheckOutput()
110125
{

0 commit comments

Comments
 (0)