Skip to content

Commit a44be16

Browse files
committed
Added constructors for tensors from scalars
1 parent d205a83 commit a44be16

File tree

3 files changed

+96
-17
lines changed

3 files changed

+96
-17
lines changed

Test/TorchSharp.cs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2-
using System;
32
using TorchSharp.JIT;
43
using TorchSharp.Tensor;
54

@@ -75,6 +74,30 @@ public void CreateFloatTensorFromData()
7574
}
7675
}
7776

77+
[TestMethod]
78+
public void CreateFloatTensorFromData2()
79+
{
80+
81+
var data = new float[1000];
82+
data[100] = 1;
83+
84+
using (var tensor = data.ToTorchTensor())
85+
{
86+
Assert.AreEqual(tensor.Data[100], 1);
87+
}
88+
}
89+
90+
[TestMethod]
91+
public void CreateFloatTensorFromScalar()
92+
{
93+
float scalar = 333.0f;
94+
95+
using (var tensor = FloatTensor.From(scalar))
96+
{
97+
Assert.AreEqual(tensor.Item, 333);
98+
}
99+
}
100+
78101
[TestMethod]
79102
public void ScoreModel()
80103
{

TorchSharp/Generated/TorchTensor.generated.cs

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ public IntPtr Handle
5656
}
5757
}
5858

59+
[DllImport("LibTorchSharp")]
60+
extern static IntPtr THS_new_byteScalar(byte scalar);
61+
62+
public static ByteTensor From(byte scalar)
63+
{
64+
return new ByteTensor(THS_new_byteScalar(scalar));
65+
}
66+
5967
[DllImport("LibTorchSharp")]
6068
extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
6169

@@ -288,7 +296,7 @@ public ITorchTensor<byte> Sum()
288296

289297
public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
290298
{
291-
return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
299+
return THS_Eq(handle, target.Handle).ToTorchTensor<U>();
292300
}
293301

294302
[DllImport("LibTorchSharp")]
@@ -383,6 +391,14 @@ public IntPtr Handle
383391
}
384392
}
385393

394+
[DllImport("LibTorchSharp")]
395+
extern static IntPtr THS_new_shortScalar(short scalar);
396+
397+
public static ShortTensor From(short scalar)
398+
{
399+
return new ShortTensor(THS_new_shortScalar(scalar));
400+
}
401+
386402
[DllImport("LibTorchSharp")]
387403
extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
388404

@@ -615,7 +631,7 @@ public ITorchTensor<short> Sum()
615631

616632
public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
617633
{
618-
return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
634+
return THS_Eq(handle, target.Handle).ToTorchTensor<U>();
619635
}
620636

621637
[DllImport("LibTorchSharp")]
@@ -710,6 +726,14 @@ public IntPtr Handle
710726
}
711727
}
712728

729+
[DllImport("LibTorchSharp")]
730+
extern static IntPtr THS_new_intScalar(int scalar);
731+
732+
public static IntTensor From(int scalar)
733+
{
734+
return new IntTensor(THS_new_intScalar(scalar));
735+
}
736+
713737
[DllImport("LibTorchSharp")]
714738
extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
715739

@@ -942,7 +966,7 @@ public ITorchTensor<int> Sum()
942966

943967
public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
944968
{
945-
return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
969+
return THS_Eq(handle, target.Handle).ToTorchTensor<U>();
946970
}
947971

948972
[DllImport("LibTorchSharp")]
@@ -1037,6 +1061,14 @@ public IntPtr Handle
10371061
}
10381062
}
10391063

1064+
[DllImport("LibTorchSharp")]
1065+
extern static IntPtr THS_new_longScalar(long scalar);
1066+
1067+
public static LongTensor From(long scalar)
1068+
{
1069+
return new LongTensor(THS_new_longScalar(scalar));
1070+
}
1071+
10401072
[DllImport("LibTorchSharp")]
10411073
extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
10421074

@@ -1269,7 +1301,7 @@ public ITorchTensor<long> Sum()
12691301

12701302
public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
12711303
{
1272-
return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
1304+
return THS_Eq(handle, target.Handle).ToTorchTensor<U>();
12731305
}
12741306

12751307
[DllImport("LibTorchSharp")]
@@ -1364,6 +1396,14 @@ public IntPtr Handle
13641396
}
13651397
}
13661398

1399+
[DllImport("LibTorchSharp")]
1400+
extern static IntPtr THS_new_doubleScalar(double scalar);
1401+
1402+
public static DoubleTensor From(double scalar)
1403+
{
1404+
return new DoubleTensor(THS_new_doubleScalar(scalar));
1405+
}
1406+
13671407
[DllImport("LibTorchSharp")]
13681408
extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
13691409

@@ -1596,7 +1636,7 @@ public ITorchTensor<double> Sum()
15961636

15971637
public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
15981638
{
1599-
return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
1639+
return THS_Eq(handle, target.Handle).ToTorchTensor<U>();
16001640
}
16011641

16021642
[DllImport("LibTorchSharp")]
@@ -1691,6 +1731,14 @@ public IntPtr Handle
16911731
}
16921732
}
16931733

1734+
[DllImport("LibTorchSharp")]
1735+
extern static IntPtr THS_new_floatScalar(float scalar);
1736+
1737+
public static FloatTensor From(float scalar)
1738+
{
1739+
return new FloatTensor(THS_new_floatScalar(scalar));
1740+
}
1741+
16941742
[DllImport("LibTorchSharp")]
16951743
extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
16961744

@@ -1923,7 +1971,7 @@ public ITorchTensor<float> Sum()
19231971

19241972
public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
19251973
{
1926-
return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
1974+
return THS_Eq(handle, target.Handle).ToTorchTensor<U>();
19271975
}
19281976

19291977
[DllImport("LibTorchSharp")]
@@ -2024,39 +2072,39 @@ internal static ITorchTensor<T> ToTorchTensor<T>(this IntPtr rawTensor)
20242072
}
20252073
}
20262074

2027-
internal static ITorchTensor<T> FromArray<T>(this IntPtr rawArray)
2075+
public static ITorchTensor<T> ToTorchTensor<T>(this T[] rawArray)
20282076
{
20292077
switch (true)
20302078
{
20312079

20322080
case bool _ when typeof(T) == typeof(byte):
20332081
{
2034-
return new ByteTensor(rawArray) as ITorchTensor<T>;
2082+
return ByteTensor.From(rawArray as byte[], new long[]{}, new long[]{}) as ITorchTensor<T>;
20352083
}
20362084

20372085
case bool _ when typeof(T) == typeof(short):
20382086
{
2039-
return new ShortTensor(rawArray) as ITorchTensor<T>;
2087+
return ShortTensor.From(rawArray as short[], new long[]{}, new long[]{}) as ITorchTensor<T>;
20402088
}
20412089

20422090
case bool _ when typeof(T) == typeof(int):
20432091
{
2044-
return new IntTensor(rawArray) as ITorchTensor<T>;
2092+
return IntTensor.From(rawArray as int[], new long[]{}, new long[]{}) as ITorchTensor<T>;
20452093
}
20462094

20472095
case bool _ when typeof(T) == typeof(long):
20482096
{
2049-
return new LongTensor(rawArray) as ITorchTensor<T>;
2097+
return LongTensor.From(rawArray as long[], new long[]{}, new long[]{}) as ITorchTensor<T>;
20502098
}
20512099

20522100
case bool _ when typeof(T) == typeof(double):
20532101
{
2054-
return new DoubleTensor(rawArray) as ITorchTensor<T>;
2102+
return DoubleTensor.From(rawArray as double[], new long[]{}, new long[]{}) as ITorchTensor<T>;
20552103
}
20562104

20572105
case bool _ when typeof(T) == typeof(float):
20582106
{
2059-
return new FloatTensor(rawArray) as ITorchTensor<T>;
2107+
return FloatTensor.From(rawArray as float[], new long[]{}, new long[]{}) as ITorchTensor<T>;
20602108
}
20612109

20622110
default: throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");

TorchSharp/Generated/TorchTensor.tt

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,14 @@ foreach (var type in TorchTypeDef.Types) {
5858
}
5959
}
6060

61+
[DllImport("LibTorchSharp")]
62+
extern static IntPtr THS_new_<#=type.Storage#>Scalar(<#=type.Storage#> scalar);
63+
64+
public static <#=type.Name#>Tensor From(<#=type.Storage#> scalar)
65+
{
66+
return new <#=type.Name#>Tensor(THS_new_<#=type.Storage#>Scalar(scalar));
67+
}
68+
6169
[DllImport("LibTorchSharp")]
6270
extern static IntPtr THS_new(IntPtr rawArray, long[] dimensions, int numDimensions, long[] strides, int numStrides, sbyte type);
6371

@@ -290,7 +298,7 @@ foreach (var type in TorchTypeDef.Types) {
290298

291299
public ITorchTensor<U> Eq<U>(ITorchTensor<U> target)
292300
{
293-
return TensorExtensionMethods.ToTorchTensor<U>(THS_Eq(handle, target.Handle));
301+
return THS_Eq(handle, target.Handle).ToTorchTensor<U>();
294302
}
295303

296304
[DllImport("LibTorchSharp")]
@@ -370,7 +378,7 @@ foreach (var type in TorchTypeDef.Types) {
370378
}
371379
}
372380

373-
internal static ITorchTensor<T> FromArray<T>(this IntPtr rawArray)
381+
internal static ITorchTensor<T> FromArray<T>(this T[] rawArray)
374382
{
375383
switch (true)
376384
{
@@ -379,7 +387,7 @@ foreach (var type in TorchTypeDef.Types) {
379387
#>
380388
case bool _ when typeof(T) == typeof(<#=type.Storage#>):
381389
{
382-
return new <#=type.Name#>Tensor(rawArray) as ITorchTensor<T>;
390+
return <#=type.Name#>Tensor.From(rawArray as <#=type.Storage#>[], new long[]{}, new long[]{}) as ITorchTensor<T>;
383391
}
384392
<#
385393
}

0 commit comments

Comments
 (0)