Skip to content

Commit 4e35be3

Browse files
committed
Some update to make ToTorchTensor work properly.
1 parent a44be16 commit 4e35be3

File tree

3 files changed

+147
-75
lines changed

3 files changed

+147
-75
lines changed

Test/TorchSharp.cs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
23
using TorchSharp.JIT;
34
using TorchSharp.Tensor;
45

@@ -68,7 +69,7 @@ public void CreateFloatTensorFromData()
6869
var data = new float[1000];
6970
data[100] = 1;
7071

71-
using (var tensor = FloatTensor.From(data, new long[] { 100, 10 }, new long[] { 1, 100 }))
72+
using (var tensor = FloatTensor.From(data, new long[] { 100, 10 }))
7273
{
7374
Assert.AreEqual(tensor.Data[100], 1);
7475
}
@@ -77,11 +78,10 @@ public void CreateFloatTensorFromData()
7778
[TestMethod]
7879
public void CreateFloatTensorFromData2()
7980
{
80-
8181
var data = new float[1000];
8282
data[100] = 1;
8383

84-
using (var tensor = data.ToTorchTensor())
84+
using (var tensor = data.ToTorchTensor(new long[] { 10, 100 }))
8585
{
8686
Assert.AreEqual(tensor.Data[100], 1);
8787
}
@@ -98,6 +98,17 @@ public void CreateFloatTensorFromScalar()
9898
}
9999
}
100100

101+
[TestMethod]
102+
public void CreateFloatTensorFromScalar2()
103+
{
104+
float scalar = 333.0f;
105+
106+
using (var tensor = scalar.ToTorchTensor())
107+
{
108+
Assert.AreEqual(tensor.Item, 333);
109+
}
110+
}
111+
101112
[TestMethod]
102113
public void ScoreModel()
103114
{

TorchSharp/Generated/TorchTensor.generated.cs

Lines changed: 102 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
using System.Runtime.InteropServices;
88
using System.Text;
99

10+
[assembly: InternalsVisibleTo("TorchSharp")]
11+
1012
namespace TorchSharp.Tensor {
1113

1214

@@ -67,27 +69,27 @@ public static ByteTensor From(byte scalar)
6769
[DllImport("LibTorchSharp")]
6870
extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
6971

70-
public static ByteTensor From(IntPtr rawArray, long[] dimensions, long[] strides)
72+
public static ByteTensor From(IntPtr rawArray, long[] dimensions)
7173
{
72-
if (dimensions.Length != strides.Length)
74+
var length = dimensions.Length;
75+
var strides = new long[length];
76+
77+
strides[0] = 1;
78+
for (int i = 1; i < length; i++)
7379
{
74-
throw new ArgumentException("Dimensions and strides do not match.");
80+
strides[i] = dimensions[i - 1];
7581
}
7682

7783
return new ByteTensor(THS_new(rawArray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Byte));
7884
}
7985

80-
public static ByteTensor From(byte[] rawArray, long[] dimensions, long[] strides)
86+
public static ByteTensor From(byte[] rawArray, long[] dimensions)
8187
{
82-
if (dimensions.Length != strides.Length)
83-
{
84-
throw new ArgumentException("Dimensions and strides do not match.");
85-
}
8688
unsafe
8789
{
8890
fixed (byte* parray = rawArray)
8991
{
90-
return new ByteTensor(THS_new((IntPtr)parray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Byte));
92+
return ByteTensor.From((IntPtr)parray, dimensions);
9193
}
9294
}
9395
}
@@ -402,27 +404,27 @@ public static ShortTensor From(short scalar)
402404
[DllImport("LibTorchSharp")]
403405
extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
404406

405-
public static ShortTensor From(IntPtr rawArray, long[] dimensions, long[] strides)
407+
public static ShortTensor From(IntPtr rawArray, long[] dimensions)
406408
{
407-
if (dimensions.Length != strides.Length)
409+
var length = dimensions.Length;
410+
var strides = new long[length];
411+
412+
strides[0] = 1;
413+
for (int i = 1; i < length; i++)
408414
{
409-
throw new ArgumentException("Dimensions and strides do not match.");
415+
strides[i] = dimensions[i - 1];
410416
}
411417

412418
return new ShortTensor(THS_new(rawArray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Short));
413419
}
414420

415-
public static ShortTensor From(short[] rawArray, long[] dimensions, long[] strides)
421+
public static ShortTensor From(short[] rawArray, long[] dimensions)
416422
{
417-
if (dimensions.Length != strides.Length)
418-
{
419-
throw new ArgumentException("Dimensions and strides do not match.");
420-
}
421423
unsafe
422424
{
423425
fixed (short* parray = rawArray)
424426
{
425-
return new ShortTensor(THS_new((IntPtr)parray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Short));
427+
return ShortTensor.From((IntPtr)parray, dimensions);
426428
}
427429
}
428430
}
@@ -737,27 +739,27 @@ public static IntTensor From(int scalar)
737739
[DllImport("LibTorchSharp")]
738740
extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
739741

740-
public static IntTensor From(IntPtr rawArray, long[] dimensions, long[] strides)
742+
public static IntTensor From(IntPtr rawArray, long[] dimensions)
741743
{
742-
if (dimensions.Length != strides.Length)
744+
var length = dimensions.Length;
745+
var strides = new long[length];
746+
747+
strides[0] = 1;
748+
for (int i = 1; i < length; i++)
743749
{
744-
throw new ArgumentException("Dimensions and strides do not match.");
750+
strides[i] = dimensions[i - 1];
745751
}
746752

747753
return new IntTensor(THS_new(rawArray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Int));
748754
}
749755

750-
public static IntTensor From(int[] rawArray, long[] dimensions, long[] strides)
756+
public static IntTensor From(int[] rawArray, long[] dimensions)
751757
{
752-
if (dimensions.Length != strides.Length)
753-
{
754-
throw new ArgumentException("Dimensions and strides do not match.");
755-
}
756758
unsafe
757759
{
758760
fixed (int* parray = rawArray)
759761
{
760-
return new IntTensor(THS_new((IntPtr)parray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Int));
762+
return IntTensor.From((IntPtr)parray, dimensions);
761763
}
762764
}
763765
}
@@ -1072,27 +1074,27 @@ public static LongTensor From(long scalar)
10721074
[DllImport("LibTorchSharp")]
10731075
extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
10741076

1075-
public static LongTensor From(IntPtr rawArray, long[] dimensions, long[] strides)
1077+
public static LongTensor From(IntPtr rawArray, long[] dimensions)
10761078
{
1077-
if (dimensions.Length != strides.Length)
1079+
var length = dimensions.Length;
1080+
var strides = new long[length];
1081+
1082+
strides[0] = 1;
1083+
for (int i = 1; i < length; i++)
10781084
{
1079-
throw new ArgumentException("Dimensions and strides do not match.");
1085+
strides[i] = dimensions[i - 1];
10801086
}
10811087

10821088
return new LongTensor(THS_new(rawArray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Long));
10831089
}
10841090

1085-
public static LongTensor From(long[] rawArray, long[] dimensions, long[] strides)
1091+
public static LongTensor From(long[] rawArray, long[] dimensions)
10861092
{
1087-
if (dimensions.Length != strides.Length)
1088-
{
1089-
throw new ArgumentException("Dimensions and strides do not match.");
1090-
}
10911093
unsafe
10921094
{
10931095
fixed (long* parray = rawArray)
10941096
{
1095-
return new LongTensor(THS_new((IntPtr)parray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Long));
1097+
return LongTensor.From((IntPtr)parray, dimensions);
10961098
}
10971099
}
10981100
}
@@ -1407,27 +1409,27 @@ public static DoubleTensor From(double scalar)
14071409
[DllImport("LibTorchSharp")]
14081410
extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
14091411

1410-
public static DoubleTensor From(IntPtr rawArray, long[] dimensions, long[] strides)
1412+
public static DoubleTensor From(IntPtr rawArray, long[] dimensions)
14111413
{
1412-
if (dimensions.Length != strides.Length)
1414+
var length = dimensions.Length;
1415+
var strides = new long[length];
1416+
1417+
strides[0] = 1;
1418+
for (int i = 1; i < length; i++)
14131419
{
1414-
throw new ArgumentException("Dimensions and strides do not match.");
1420+
strides[i] = dimensions[i - 1];
14151421
}
14161422

14171423
return new DoubleTensor(THS_new(rawArray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Double));
14181424
}
14191425

1420-
public static DoubleTensor From(double[] rawArray, long[] dimensions, long[] strides)
1426+
public static DoubleTensor From(double[] rawArray, long[] dimensions)
14211427
{
1422-
if (dimensions.Length != strides.Length)
1423-
{
1424-
throw new ArgumentException("Dimensions and strides do not match.");
1425-
}
14261428
unsafe
14271429
{
14281430
fixed (double* parray = rawArray)
14291431
{
1430-
return new DoubleTensor(THS_new((IntPtr)parray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Double));
1432+
return DoubleTensor.From((IntPtr)parray, dimensions);
14311433
}
14321434
}
14331435
}
@@ -1742,27 +1744,27 @@ public static FloatTensor From(float scalar)
17421744
[DllImport("LibTorchSharp")]
17431745
extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
17441746

1745-
public static FloatTensor From(IntPtr rawArray, long[] dimensions, long[] strides)
1747+
public static FloatTensor From(IntPtr rawArray, long[] dimensions)
17461748
{
1747-
if (dimensions.Length != strides.Length)
1749+
var length = dimensions.Length;
1750+
var strides = new long[length];
1751+
1752+
strides[0] = 1;
1753+
for (int i = 1; i < length; i++)
17481754
{
1749-
throw new ArgumentException("Dimensions and strides do not match.");
1755+
strides[i] = dimensions[i - 1];
17501756
}
17511757

17521758
return new FloatTensor(THS_new(rawArray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Float));
17531759
}
17541760

1755-
public static FloatTensor From(float[] rawArray, long[] dimensions, long[] strides)
1761+
public static FloatTensor From(float[] rawArray, long[] dimensions)
17561762
{
1757-
if (dimensions.Length != strides.Length)
1758-
{
1759-
throw new ArgumentException("Dimensions and strides do not match.");
1760-
}
17611763
unsafe
17621764
{
17631765
fixed (float* parray = rawArray)
17641766
{
1765-
return new FloatTensor(THS_new((IntPtr)parray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Float));
1767+
return FloatTensor.From((IntPtr)parray, dimensions);
17661768
}
17671769
}
17681770
}
@@ -2072,39 +2074,78 @@ internal static ITorchTensor<T> ToTorchTensor<T>(this IntPtr rawTensor)
20722074
}
20732075
}
20742076

2075-
public static ITorchTensor<T> ToTorchTensor<T>(this T[] rawArray)
2077+
public static ITorchTensor<T> ToTorchTensor<T>(this T[] rawArray, long[] dimensions)
2078+
{
2079+
switch (true)
2080+
{
2081+
2082+
case bool _ when typeof(T) == typeof(byte):
2083+
{
2084+
return ByteTensor.From(rawArray as byte[], dimensions) as ITorchTensor<T>;
2085+
}
2086+
2087+
case bool _ when typeof(T) == typeof(short):
2088+
{
2089+
return ShortTensor.From(rawArray as short[], dimensions) as ITorchTensor<T>;
2090+
}
2091+
2092+
case bool _ when typeof(T) == typeof(int):
2093+
{
2094+
return IntTensor.From(rawArray as int[], dimensions) as ITorchTensor<T>;
2095+
}
2096+
2097+
case bool _ when typeof(T) == typeof(long):
2098+
{
2099+
return LongTensor.From(rawArray as long[], dimensions) as ITorchTensor<T>;
2100+
}
2101+
2102+
case bool _ when typeof(T) == typeof(double):
2103+
{
2104+
return DoubleTensor.From(rawArray as double[], dimensions) as ITorchTensor<T>;
2105+
}
2106+
2107+
case bool _ when typeof(T) == typeof(float):
2108+
{
2109+
return FloatTensor.From(rawArray as float[], dimensions) as ITorchTensor<T>;
2110+
}
2111+
2112+
default: throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
2113+
}
2114+
}
2115+
2116+
public static ITorchTensor<T> ToTorchTensor<T>(this T scalar)
20762117
{
20772118
switch (true)
20782119
{
20792120

20802121
case bool _ when typeof(T) == typeof(byte):
20812122
{
2082-
return ByteTensor.From(rawArray as byte[], new long[]{}, new long[]{}) as ITorchTensor<T>;
2123+
return ByteTensor.From((byte)(object)scalar) as ITorchTensor<T>;
20832124
}
20842125

20852126
case bool _ when typeof(T) == typeof(short):
20862127
{
2087-
return ShortTensor.From(rawArray as short[], new long[]{}, new long[]{}) as ITorchTensor<T>;
2128+
return ShortTensor.From((short)(object)scalar) as ITorchTensor<T>;
20882129
}
20892130

20902131
case bool _ when typeof(T) == typeof(int):
20912132
{
2092-
return IntTensor.From(rawArray as int[], new long[]{}, new long[]{}) as ITorchTensor<T>;
2133+
return IntTensor.From((int)(object)scalar) as ITorchTensor<T>;
20932134
}
20942135

20952136
case bool _ when typeof(T) == typeof(long):
20962137
{
2097-
return LongTensor.From(rawArray as long[], new long[]{}, new long[]{}) as ITorchTensor<T>;
2138+
return LongTensor.From((long)(object)scalar) as ITorchTensor<T>;
20982139
}
20992140

21002141
case bool _ when typeof(T) == typeof(double):
21012142
{
2102-
return DoubleTensor.From(rawArray as double[], new long[]{}, new long[]{}) as ITorchTensor<T>;
2143+
return DoubleTensor.From((double)(object)scalar) as ITorchTensor<T>;
21032144
}
21042145

21052146
case bool _ when typeof(T) == typeof(float):
21062147
{
2107-
return FloatTensor.From(rawArray as float[], new long[]{}, new long[]{}) as ITorchTensor<T>;
2148+
return FloatTensor.From((float)(object)scalar) as ITorchTensor<T>;
21082149
}
21092150

21102151
default: throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");

0 commit comments

Comments
 (0)