|
7 | 7 | using System.Runtime.InteropServices;
|
8 | 8 | using System.Text;
|
9 | 9 |
|
| 10 | +[assembly: InternalsVisibleTo("TorchSharp")] |
| 11 | + |
10 | 12 | namespace TorchSharp.Tensor {
|
11 | 13 |
|
12 | 14 |
|
@@ -67,27 +69,27 @@ public static ByteTensor From(byte scalar)
|
67 | 69 | [DllImport("LibTorchSharp")]
|
68 | 70 | extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
|
69 | 71 |
|
70 |
| - public static ByteTensor From(IntPtr rawArray, long[] dimensions, long[] strides) |
| 72 | + public static ByteTensor From(IntPtr rawArray, long[] dimensions) |
71 | 73 | {
|
72 |
| - if (dimensions.Length != strides.Length) |
| 74 | + var length = dimensions.Length; |
| 75 | + var strides = new long[length]; |
| 76 | + |
| 77 | + strides[0] = 1; |
| 78 | + for (int i = 1; i < length; i++) |
73 | 79 | {
|
74 |
| - throw new ArgumentException("Dimensions and strides do not match."); |
| 80 | + strides[i] = dimensions[i - 1]; |
75 | 81 | }
|
76 | 82 |
|
77 | 83 | return new ByteTensor(THS_new(rawArray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Byte));
|
78 | 84 | }
|
79 | 85 |
|
80 |
| - public static ByteTensor From(byte[] rawArray, long[] dimensions, long[] strides) |
| 86 | + public static ByteTensor From(byte[] rawArray, long[] dimensions) |
81 | 87 | {
|
82 |
| - if (dimensions.Length != strides.Length) |
83 |
| - { |
84 |
| - throw new ArgumentException("Dimensions and strides do not match."); |
85 |
| - } |
86 | 88 | unsafe
|
87 | 89 | {
|
88 | 90 | fixed (byte* parray = rawArray)
|
89 | 91 | {
|
90 |
| - return new ByteTensor(THS_new((IntPtr)parray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Byte)); |
| 92 | + return ByteTensor.From((IntPtr)parray, dimensions); |
91 | 93 | }
|
92 | 94 | }
|
93 | 95 | }
|
@@ -402,27 +404,27 @@ public static ShortTensor From(short scalar)
|
402 | 404 | [DllImport("LibTorchSharp")]
|
403 | 405 | extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
|
404 | 406 |
|
405 |
| - public static ShortTensor From(IntPtr rawArray, long[] dimensions, long[] strides) |
| 407 | + public static ShortTensor From(IntPtr rawArray, long[] dimensions) |
406 | 408 | {
|
407 |
| - if (dimensions.Length != strides.Length) |
| 409 | + var length = dimensions.Length; |
| 410 | + var strides = new long[length]; |
| 411 | + |
| 412 | + strides[0] = 1; |
| 413 | + for (int i = 1; i < length; i++) |
408 | 414 | {
|
409 |
| - throw new ArgumentException("Dimensions and strides do not match."); |
| 415 | + strides[i] = dimensions[i - 1]; |
410 | 416 | }
|
411 | 417 |
|
412 | 418 | return new ShortTensor(THS_new(rawArray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Short));
|
413 | 419 | }
|
414 | 420 |
|
415 |
| - public static ShortTensor From(short[] rawArray, long[] dimensions, long[] strides) |
| 421 | + public static ShortTensor From(short[] rawArray, long[] dimensions) |
416 | 422 | {
|
417 |
| - if (dimensions.Length != strides.Length) |
418 |
| - { |
419 |
| - throw new ArgumentException("Dimensions and strides do not match."); |
420 |
| - } |
421 | 423 | unsafe
|
422 | 424 | {
|
423 | 425 | fixed (short* parray = rawArray)
|
424 | 426 | {
|
425 |
| - return new ShortTensor(THS_new((IntPtr)parray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Short)); |
| 427 | + return ShortTensor.From((IntPtr)parray, dimensions); |
426 | 428 | }
|
427 | 429 | }
|
428 | 430 | }
|
@@ -737,27 +739,27 @@ public static IntTensor From(int scalar)
|
737 | 739 | [DllImport("LibTorchSharp")]
|
738 | 740 | extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
|
739 | 741 |
|
740 |
| - public static IntTensor From(IntPtr rawArray, long[] dimensions, long[] strides) |
| 742 | + public static IntTensor From(IntPtr rawArray, long[] dimensions) |
741 | 743 | {
|
742 |
| - if (dimensions.Length != strides.Length) |
| 744 | + var length = dimensions.Length; |
| 745 | + var strides = new long[length]; |
| 746 | + |
| 747 | + strides[0] = 1; |
| 748 | + for (int i = 1; i < length; i++) |
743 | 749 | {
|
744 |
| - throw new ArgumentException("Dimensions and strides do not match."); |
| 750 | + strides[i] = dimensions[i - 1]; |
745 | 751 | }
|
746 | 752 |
|
747 | 753 | return new IntTensor(THS_new(rawArray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Int));
|
748 | 754 | }
|
749 | 755 |
|
750 |
| - public static IntTensor From(int[] rawArray, long[] dimensions, long[] strides) |
| 756 | + public static IntTensor From(int[] rawArray, long[] dimensions) |
751 | 757 | {
|
752 |
| - if (dimensions.Length != strides.Length) |
753 |
| - { |
754 |
| - throw new ArgumentException("Dimensions and strides do not match."); |
755 |
| - } |
756 | 758 | unsafe
|
757 | 759 | {
|
758 | 760 | fixed (int* parray = rawArray)
|
759 | 761 | {
|
760 |
| - return new IntTensor(THS_new((IntPtr)parray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Int)); |
| 762 | + return IntTensor.From((IntPtr)parray, dimensions); |
761 | 763 | }
|
762 | 764 | }
|
763 | 765 | }
|
@@ -1072,27 +1074,27 @@ public static LongTensor From(long scalar)
|
1072 | 1074 | [DllImport("LibTorchSharp")]
|
1073 | 1075 | extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
|
1074 | 1076 |
|
1075 |
| - public static LongTensor From(IntPtr rawArray, long[] dimensions, long[] strides) |
| 1077 | + public static LongTensor From(IntPtr rawArray, long[] dimensions) |
1076 | 1078 | {
|
1077 |
| - if (dimensions.Length != strides.Length) |
| 1079 | + var length = dimensions.Length; |
| 1080 | + var strides = new long[length]; |
| 1081 | + |
| 1082 | + strides[0] = 1; |
| 1083 | + for (int i = 1; i < length; i++) |
1078 | 1084 | {
|
1079 |
| - throw new ArgumentException("Dimensions and strides do not match."); |
| 1085 | + strides[i] = dimensions[i - 1]; |
1080 | 1086 | }
|
1081 | 1087 |
|
1082 | 1088 | return new LongTensor(THS_new(rawArray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Long));
|
1083 | 1089 | }
|
1084 | 1090 |
|
1085 |
| - public static LongTensor From(long[] rawArray, long[] dimensions, long[] strides) |
| 1091 | + public static LongTensor From(long[] rawArray, long[] dimensions) |
1086 | 1092 | {
|
1087 |
| - if (dimensions.Length != strides.Length) |
1088 |
| - { |
1089 |
| - throw new ArgumentException("Dimensions and strides do not match."); |
1090 |
| - } |
1091 | 1093 | unsafe
|
1092 | 1094 | {
|
1093 | 1095 | fixed (long* parray = rawArray)
|
1094 | 1096 | {
|
1095 |
| - return new LongTensor(THS_new((IntPtr)parray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Long)); |
| 1097 | + return LongTensor.From((IntPtr)parray, dimensions); |
1096 | 1098 | }
|
1097 | 1099 | }
|
1098 | 1100 | }
|
@@ -1407,27 +1409,27 @@ public static DoubleTensor From(double scalar)
|
1407 | 1409 | [DllImport("LibTorchSharp")]
|
1408 | 1410 | extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
|
1409 | 1411 |
|
1410 |
| - public static DoubleTensor From(IntPtr rawArray, long[] dimensions, long[] strides) |
| 1412 | + public static DoubleTensor From(IntPtr rawArray, long[] dimensions) |
1411 | 1413 | {
|
1412 |
| - if (dimensions.Length != strides.Length) |
| 1414 | + var length = dimensions.Length; |
| 1415 | + var strides = new long[length]; |
| 1416 | + |
| 1417 | + strides[0] = 1; |
| 1418 | + for (int i = 1; i < length; i++) |
1413 | 1419 | {
|
1414 |
| - throw new ArgumentException("Dimensions and strides do not match."); |
| 1420 | + strides[i] = dimensions[i - 1]; |
1415 | 1421 | }
|
1416 | 1422 |
|
1417 | 1423 | return new DoubleTensor(THS_new(rawArray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Double));
|
1418 | 1424 | }
|
1419 | 1425 |
|
1420 |
| - public static DoubleTensor From(double[] rawArray, long[] dimensions, long[] strides) |
| 1426 | + public static DoubleTensor From(double[] rawArray, long[] dimensions) |
1421 | 1427 | {
|
1422 |
| - if (dimensions.Length != strides.Length) |
1423 |
| - { |
1424 |
| - throw new ArgumentException("Dimensions and strides do not match."); |
1425 |
| - } |
1426 | 1428 | unsafe
|
1427 | 1429 | {
|
1428 | 1430 | fixed (double* parray = rawArray)
|
1429 | 1431 | {
|
1430 |
| - return new DoubleTensor(THS_new((IntPtr)parray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Double)); |
| 1432 | + return DoubleTensor.From((IntPtr)parray, dimensions); |
1431 | 1433 | }
|
1432 | 1434 | }
|
1433 | 1435 | }
|
@@ -1742,27 +1744,27 @@ public static FloatTensor From(float scalar)
|
1742 | 1744 | [DllImport("LibTorchSharp")]
|
1743 | 1745 | extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
|
1744 | 1746 |
|
1745 |
| - public static FloatTensor From(IntPtr rawArray, long[] dimensions, long[] strides) |
| 1747 | + public static FloatTensor From(IntPtr rawArray, long[] dimensions) |
1746 | 1748 | {
|
1747 |
| - if (dimensions.Length != strides.Length) |
| 1749 | + var length = dimensions.Length; |
| 1750 | + var strides = new long[length]; |
| 1751 | + |
| 1752 | + strides[0] = 1; |
| 1753 | + for (int i = 1; i < length; i++) |
1748 | 1754 | {
|
1749 |
| - throw new ArgumentException("Dimensions and strides do not match."); |
| 1755 | + strides[i] = dimensions[i - 1]; |
1750 | 1756 | }
|
1751 | 1757 |
|
1752 | 1758 | return new FloatTensor(THS_new(rawArray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Float));
|
1753 | 1759 | }
|
1754 | 1760 |
|
1755 |
| - public static FloatTensor From(float[] rawArray, long[] dimensions, long[] strides) |
| 1761 | + public static FloatTensor From(float[] rawArray, long[] dimensions) |
1756 | 1762 | {
|
1757 |
| - if (dimensions.Length != strides.Length) |
1758 |
| - { |
1759 |
| - throw new ArgumentException("Dimensions and strides do not match."); |
1760 |
| - } |
1761 | 1763 | unsafe
|
1762 | 1764 | {
|
1763 | 1765 | fixed (float* parray = rawArray)
|
1764 | 1766 | {
|
1765 |
| - return new FloatTensor(THS_new((IntPtr)parray, dimensions, dimensions.Length, strides, strides.Length, (sbyte)ATenScalarMapping.Float)); |
| 1767 | + return FloatTensor.From((IntPtr)parray, dimensions); |
1766 | 1768 | }
|
1767 | 1769 | }
|
1768 | 1770 | }
|
@@ -2072,39 +2074,78 @@ internal static ITorchTensor<T> ToTorchTensor<T>(this IntPtr rawTensor)
|
2072 | 2074 | }
|
2073 | 2075 | }
|
2074 | 2076 |
|
2075 |
| - public static ITorchTensor<T> ToTorchTensor<T>(this T[] rawArray) |
| 2077 | + public static ITorchTensor<T> ToTorchTensor<T>(this T[] rawArray, long[] dimensions) |
| 2078 | + { |
| 2079 | + switch (true) |
| 2080 | + { |
| 2081 | + |
| 2082 | + case bool _ when typeof(T) == typeof(byte): |
| 2083 | + { |
| 2084 | + return ByteTensor.From(rawArray as byte[], dimensions) as ITorchTensor<T>; |
| 2085 | + } |
| 2086 | + |
| 2087 | + case bool _ when typeof(T) == typeof(short): |
| 2088 | + { |
| 2089 | + return ShortTensor.From(rawArray as short[], dimensions) as ITorchTensor<T>; |
| 2090 | + } |
| 2091 | + |
| 2092 | + case bool _ when typeof(T) == typeof(int): |
| 2093 | + { |
| 2094 | + return IntTensor.From(rawArray as int[], dimensions) as ITorchTensor<T>; |
| 2095 | + } |
| 2096 | + |
| 2097 | + case bool _ when typeof(T) == typeof(long): |
| 2098 | + { |
| 2099 | + return LongTensor.From(rawArray as long[], dimensions) as ITorchTensor<T>; |
| 2100 | + } |
| 2101 | + |
| 2102 | + case bool _ when typeof(T) == typeof(double): |
| 2103 | + { |
| 2104 | + return DoubleTensor.From(rawArray as double[], dimensions) as ITorchTensor<T>; |
| 2105 | + } |
| 2106 | + |
| 2107 | + case bool _ when typeof(T) == typeof(float): |
| 2108 | + { |
| 2109 | + return FloatTensor.From(rawArray as float[], dimensions) as ITorchTensor<T>; |
| 2110 | + } |
| 2111 | + |
| 2112 | + default: throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported."); |
| 2113 | + } |
| 2114 | + } |
| 2115 | + |
| 2116 | + public static ITorchTensor<T> ToTorchTensor<T>(this T scalar) |
2076 | 2117 | {
|
2077 | 2118 | switch (true)
|
2078 | 2119 | {
|
2079 | 2120 |
|
2080 | 2121 | case bool _ when typeof(T) == typeof(byte):
|
2081 | 2122 | {
|
2082 |
| - return ByteTensor.From(rawArray as byte[], new long[]{}, new long[]{}) as ITorchTensor<T>; |
| 2123 | + return ByteTensor.From((byte)(object)scalar) as ITorchTensor<T>; |
2083 | 2124 | }
|
2084 | 2125 |
|
2085 | 2126 | case bool _ when typeof(T) == typeof(short):
|
2086 | 2127 | {
|
2087 |
| - return ShortTensor.From(rawArray as short[], new long[]{}, new long[]{}) as ITorchTensor<T>; |
| 2128 | + return ShortTensor.From((short)(object)scalar) as ITorchTensor<T>; |
2088 | 2129 | }
|
2089 | 2130 |
|
2090 | 2131 | case bool _ when typeof(T) == typeof(int):
|
2091 | 2132 | {
|
2092 |
| - return IntTensor.From(rawArray as int[], new long[]{}, new long[]{}) as ITorchTensor<T>; |
| 2133 | + return IntTensor.From((int)(object)scalar) as ITorchTensor<T>; |
2093 | 2134 | }
|
2094 | 2135 |
|
2095 | 2136 | case bool _ when typeof(T) == typeof(long):
|
2096 | 2137 | {
|
2097 |
| - return LongTensor.From(rawArray as long[], new long[]{}, new long[]{}) as ITorchTensor<T>; |
| 2138 | + return LongTensor.From((long)(object)scalar) as ITorchTensor<T>; |
2098 | 2139 | }
|
2099 | 2140 |
|
2100 | 2141 | case bool _ when typeof(T) == typeof(double):
|
2101 | 2142 | {
|
2102 |
| - return DoubleTensor.From(rawArray as double[], new long[]{}, new long[]{}) as ITorchTensor<T>; |
| 2143 | + return DoubleTensor.From((double)(object)scalar) as ITorchTensor<T>; |
2103 | 2144 | }
|
2104 | 2145 |
|
2105 | 2146 | case bool _ when typeof(T) == typeof(float):
|
2106 | 2147 | {
|
2107 |
| - return FloatTensor.From(rawArray as float[], new long[]{}, new long[]{}) as ITorchTensor<T>; |
| 2148 | + return FloatTensor.From((float)(object)scalar) as ITorchTensor<T>; |
2108 | 2149 | }
|
2109 | 2150 |
|
2110 | 2151 | default: throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
|
|
0 commit comments