@@ -56,6 +56,14 @@ public IntPtr Handle
56
56
}
57
57
}
58
58
59
+ [ DllImport ( "LibTorchSharp" ) ]
60
+ extern static IntPtr THS_new_byteScalar ( byte scalar ) ;
61
+
62
+ public static ByteTensor From ( byte scalar )
63
+ {
64
+ return new ByteTensor ( THS_new_byteScalar ( scalar ) ) ;
65
+ }
66
+
59
67
[ DllImport ( "LibTorchSharp" ) ]
60
68
extern static IntPtr THS_new ( IntPtr rawArray , long [ ] dimensions , int numDimensions , long [ ] strides , int numStrides , sbyte type ) ;
61
69
@@ -288,7 +296,7 @@ public ITorchTensor<byte> Sum()
288
296
289
297
public ITorchTensor < U > Eq < U > ( ITorchTensor < U > target )
290
298
{
291
- return TensorExtensionMethods . ToTorchTensor < U > ( THS_Eq ( handle , target . Handle ) ) ;
299
+ return THS_Eq ( handle , target . Handle ) . ToTorchTensor < U > ( ) ;
292
300
}
293
301
294
302
[ DllImport ( "LibTorchSharp" ) ]
@@ -383,6 +391,14 @@ public IntPtr Handle
383
391
}
384
392
}
385
393
394
+ [ DllImport ( "LibTorchSharp" ) ]
395
+ extern static IntPtr THS_new_shortScalar ( short scalar ) ;
396
+
397
+ public static ShortTensor From ( short scalar )
398
+ {
399
+ return new ShortTensor ( THS_new_shortScalar ( scalar ) ) ;
400
+ }
401
+
386
402
[ DllImport ( "LibTorchSharp" ) ]
387
403
extern static IntPtr THS_new ( IntPtr rawArray , long [ ] dimensions , int numDimensions , long [ ] strides , int numStrides , sbyte type ) ;
388
404
@@ -615,7 +631,7 @@ public ITorchTensor<short> Sum()
615
631
616
632
public ITorchTensor < U > Eq < U > ( ITorchTensor < U > target )
617
633
{
618
- return TensorExtensionMethods . ToTorchTensor < U > ( THS_Eq ( handle , target . Handle ) ) ;
634
+ return THS_Eq ( handle , target . Handle ) . ToTorchTensor < U > ( ) ;
619
635
}
620
636
621
637
[ DllImport ( "LibTorchSharp" ) ]
@@ -710,6 +726,14 @@ public IntPtr Handle
710
726
}
711
727
}
712
728
729
+ [ DllImport ( "LibTorchSharp" ) ]
730
+ extern static IntPtr THS_new_intScalar ( int scalar ) ;
731
+
732
+ public static IntTensor From ( int scalar )
733
+ {
734
+ return new IntTensor ( THS_new_intScalar ( scalar ) ) ;
735
+ }
736
+
713
737
[ DllImport ( "LibTorchSharp" ) ]
714
738
extern static IntPtr THS_new ( IntPtr rawArray , long [ ] dimensions , int numDimensions , long [ ] strides , int numStrides , sbyte type ) ;
715
739
@@ -942,7 +966,7 @@ public ITorchTensor<int> Sum()
942
966
943
967
public ITorchTensor < U > Eq < U > ( ITorchTensor < U > target )
944
968
{
945
- return TensorExtensionMethods . ToTorchTensor < U > ( THS_Eq ( handle , target . Handle ) ) ;
969
+ return THS_Eq ( handle , target . Handle ) . ToTorchTensor < U > ( ) ;
946
970
}
947
971
948
972
[ DllImport ( "LibTorchSharp" ) ]
@@ -1037,6 +1061,14 @@ public IntPtr Handle
1037
1061
}
1038
1062
}
1039
1063
1064
+ [ DllImport ( "LibTorchSharp" ) ]
1065
+ extern static IntPtr THS_new_longScalar ( long scalar ) ;
1066
+
1067
+ public static LongTensor From ( long scalar )
1068
+ {
1069
+ return new LongTensor ( THS_new_longScalar ( scalar ) ) ;
1070
+ }
1071
+
1040
1072
[ DllImport ( "LibTorchSharp" ) ]
1041
1073
extern static IntPtr THS_new ( IntPtr rawArray , long [ ] dimensions , int numDimensions , long [ ] strides , int numStrides , sbyte type ) ;
1042
1074
@@ -1269,7 +1301,7 @@ public ITorchTensor<long> Sum()
1269
1301
1270
1302
public ITorchTensor < U > Eq < U > ( ITorchTensor < U > target )
1271
1303
{
1272
- return TensorExtensionMethods . ToTorchTensor < U > ( THS_Eq ( handle , target . Handle ) ) ;
1304
+ return THS_Eq ( handle , target . Handle ) . ToTorchTensor < U > ( ) ;
1273
1305
}
1274
1306
1275
1307
[ DllImport ( "LibTorchSharp" ) ]
@@ -1364,6 +1396,14 @@ public IntPtr Handle
1364
1396
}
1365
1397
}
1366
1398
1399
+ [ DllImport ( "LibTorchSharp" ) ]
1400
+ extern static IntPtr THS_new_doubleScalar ( double scalar ) ;
1401
+
1402
+ public static DoubleTensor From ( double scalar )
1403
+ {
1404
+ return new DoubleTensor ( THS_new_doubleScalar ( scalar ) ) ;
1405
+ }
1406
+
1367
1407
[ DllImport ( "LibTorchSharp" ) ]
1368
1408
extern static IntPtr THS_new ( IntPtr rawArray , long [ ] dimensions , int numDimensions , long [ ] strides , int numStrides , sbyte type ) ;
1369
1409
@@ -1596,7 +1636,7 @@ public ITorchTensor<double> Sum()
1596
1636
1597
1637
public ITorchTensor < U > Eq < U > ( ITorchTensor < U > target )
1598
1638
{
1599
- return TensorExtensionMethods . ToTorchTensor < U > ( THS_Eq ( handle , target . Handle ) ) ;
1639
+ return THS_Eq ( handle , target . Handle ) . ToTorchTensor < U > ( ) ;
1600
1640
}
1601
1641
1602
1642
[ DllImport ( "LibTorchSharp" ) ]
@@ -1691,6 +1731,14 @@ public IntPtr Handle
1691
1731
}
1692
1732
}
1693
1733
1734
+ [ DllImport ( "LibTorchSharp" ) ]
1735
+ extern static IntPtr THS_new_floatScalar ( float scalar ) ;
1736
+
1737
+ public static FloatTensor From ( float scalar )
1738
+ {
1739
+ return new FloatTensor ( THS_new_floatScalar ( scalar ) ) ;
1740
+ }
1741
+
1694
1742
[ DllImport ( "LibTorchSharp" ) ]
1695
1743
extern static IntPtr THS_new ( IntPtr rawArray , long [ ] dimensions , int numDimensions , long [ ] strides , int numStrides , sbyte type ) ;
1696
1744
@@ -1923,7 +1971,7 @@ public ITorchTensor<float> Sum()
1923
1971
1924
1972
public ITorchTensor < U > Eq < U > ( ITorchTensor < U > target )
1925
1973
{
1926
- return TensorExtensionMethods . ToTorchTensor < U > ( THS_Eq ( handle , target . Handle ) ) ;
1974
+ return THS_Eq ( handle , target . Handle ) . ToTorchTensor < U > ( ) ;
1927
1975
}
1928
1976
1929
1977
[ DllImport ( "LibTorchSharp" ) ]
@@ -2024,39 +2072,39 @@ internal static ITorchTensor<T> ToTorchTensor<T>(this IntPtr rawTensor)
2024
2072
}
2025
2073
}
2026
2074
2027
- internal static ITorchTensor < T > FromArray < T > ( this IntPtr rawArray )
2075
+ public static ITorchTensor < T > ToTorchTensor < T > ( this T [ ] rawArray )
2028
2076
{
2029
2077
switch ( true )
2030
2078
{
2031
2079
2032
2080
case bool _ when typeof ( T ) == typeof ( byte ) :
2033
2081
{
2034
- return new ByteTensor ( rawArray ) as ITorchTensor < T > ;
2082
+ return ByteTensor . From ( rawArray as byte [ ] , new long [ ] { } , new long [ ] { } ) as ITorchTensor < T > ;
2035
2083
}
2036
2084
2037
2085
case bool _ when typeof ( T ) == typeof ( short ) :
2038
2086
{
2039
- return new ShortTensor ( rawArray ) as ITorchTensor < T > ;
2087
+ return ShortTensor . From ( rawArray as short [ ] , new long [ ] { } , new long [ ] { } ) as ITorchTensor < T > ;
2040
2088
}
2041
2089
2042
2090
case bool _ when typeof ( T ) == typeof ( int ) :
2043
2091
{
2044
- return new IntTensor ( rawArray ) as ITorchTensor < T > ;
2092
+ return IntTensor . From ( rawArray as int [ ] , new long [ ] { } , new long [ ] { } ) as ITorchTensor < T > ;
2045
2093
}
2046
2094
2047
2095
case bool _ when typeof ( T ) == typeof ( long ) :
2048
2096
{
2049
- return new LongTensor ( rawArray ) as ITorchTensor < T > ;
2097
+ return LongTensor . From ( rawArray as long [ ] , new long [ ] { } , new long [ ] { } ) as ITorchTensor < T > ;
2050
2098
}
2051
2099
2052
2100
case bool _ when typeof ( T ) == typeof ( double ) :
2053
2101
{
2054
- return new DoubleTensor ( rawArray ) as ITorchTensor < T > ;
2102
+ return DoubleTensor . From ( rawArray as double [ ] , new long [ ] { } , new long [ ] { } ) as ITorchTensor < T > ;
2055
2103
}
2056
2104
2057
2105
case bool _ when typeof ( T ) == typeof ( float ) :
2058
2106
{
2059
- return new FloatTensor ( rawArray ) as ITorchTensor < T > ;
2107
+ return FloatTensor . From ( rawArray as float [ ] , new long [ ] { } , new long [ ] { } ) as ITorchTensor < T > ;
2060
2108
}
2061
2109
2062
2110
default : throw new NotImplementedException ( $ "Creating tensor of type { typeof ( T ) } is not supported.") ;
0 commit comments