Skip to content

Commit 6f2e094

Browse files
committed
Merge branch 'LibTorchSharpFirstTest' of https://github.com/interesaaat/TorchSharp into LibTorchSharpFirstTest
2 parents 7211cf4 + a119bdf commit 6f2e094

File tree

4 files changed

+151
-33
lines changed

4 files changed

+151
-33
lines changed

Test/TorchSharp.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,40 @@ public void CreateFloatTensorFromScalar2()
113113
}
114114
}
115115

116+
[TestMethod]
117+
public void CopyCpuToCuda()
118+
{
119+
var cpu = FloatTensor.Ones(new long[] { 2, 2 });
120+
Assert.AreEqual(cpu.Device, "cpu");
121+
122+
var cuda = cpu.Cuda();
123+
Assert.AreEqual(cuda.Device, "cuda");
124+
125+
// Copy back to CPU to inspect the elements
126+
cpu = cuda.Cpu();
127+
var data = cpu.Data;
128+
for (int i = 0; i < 4; i++)
129+
{
130+
Assert.AreEqual(data[i], 1);
131+
}
132+
}
133+
134+
[TestMethod]
135+
public void CopyCudaToCpu()
136+
{
137+
var cuda = FloatTensor.Ones(new long[] { 2, 2 }, "cuda");
138+
Assert.AreEqual(cuda.Device, "cuda");
139+
140+
var cpu = cuda.Cpu();
141+
Assert.AreEqual(cpu.Device, "cpu");
142+
143+
var data = cpu.Data;
144+
for (int i = 0; i < 4; i++)
145+
{
146+
Assert.AreEqual(data[i], 1);
147+
}
148+
}
149+
116150
[TestMethod]
117151
public void ScoreModel()
118152
{

TorchSharp/Generated/TorchTensor.generated.cs

Lines changed: 97 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
1-

2-
3-
4-
5-
using System;
1+
using System;
62
using System.Linq;
73
using System.Runtime.InteropServices;
84
using System.Text;
95

106
namespace TorchSharp.Tensor {
117

12-
138
/// <summary>
149
/// Tensor of type Byte.
1510
/// This tensor maps to a Torch variable (see torch/csrc/autograd/variable.h).
@@ -181,6 +176,22 @@ public string Device
181176
}
182177
}
183178

179+
[DllImport("LibTorchSharp")]
180+
extern static IntPtr THS_cpu(IntPtr handle);
181+
182+
public ITorchTensor<byte> Cpu()
183+
{
184+
return new ByteTensor(THS_cpu(handle));
185+
}
186+
187+
[DllImport("LibTorchSharp")]
188+
extern static IntPtr THS_cuda(IntPtr handle);
189+
190+
public ITorchTensor<byte> Cuda()
191+
{
192+
return new ByteTensor(THS_cuda(handle));
193+
}
194+
184195
/// <summary>
185196
/// Retrieves the size of the specified dimension in the tensor.
186197
/// </summary>
@@ -344,7 +355,6 @@ public override string ToString()
344355
return sb.ToString();
345356
}
346357
}
347-
348358
/// <summary>
349359
/// Tensor of type Short.
350360
/// This tensor maps to a Torch variable (see torch/csrc/autograd/variable.h).
@@ -516,6 +526,22 @@ public string Device
516526
}
517527
}
518528

529+
[DllImport("LibTorchSharp")]
530+
extern static IntPtr THS_cpu(IntPtr handle);
531+
532+
public ITorchTensor<short> Cpu()
533+
{
534+
return new ShortTensor(THS_cpu(handle));
535+
}
536+
537+
[DllImport("LibTorchSharp")]
538+
extern static IntPtr THS_cuda(IntPtr handle);
539+
540+
public ITorchTensor<short> Cuda()
541+
{
542+
return new ShortTensor(THS_cuda(handle));
543+
}
544+
519545
/// <summary>
520546
/// Retrieves the size of the specified dimension in the tensor.
521547
/// </summary>
@@ -679,7 +705,6 @@ public override string ToString()
679705
return sb.ToString();
680706
}
681707
}
682-
683708
/// <summary>
684709
/// Tensor of type Int.
685710
/// This tensor maps to a Torch variable (see torch/csrc/autograd/variable.h).
@@ -851,6 +876,22 @@ public string Device
851876
}
852877
}
853878

879+
[DllImport("LibTorchSharp")]
880+
extern static IntPtr THS_cpu(IntPtr handle);
881+
882+
public ITorchTensor<int> Cpu()
883+
{
884+
return new IntTensor(THS_cpu(handle));
885+
}
886+
887+
[DllImport("LibTorchSharp")]
888+
extern static IntPtr THS_cuda(IntPtr handle);
889+
890+
public ITorchTensor<int> Cuda()
891+
{
892+
return new IntTensor(THS_cuda(handle));
893+
}
894+
854895
/// <summary>
855896
/// Retrieves the size of the specified dimension in the tensor.
856897
/// </summary>
@@ -1014,7 +1055,6 @@ public override string ToString()
10141055
return sb.ToString();
10151056
}
10161057
}
1017-
10181058
/// <summary>
10191059
/// Tensor of type Long.
10201060
/// This tensor maps to a Torch variable (see torch/csrc/autograd/variable.h).
@@ -1186,6 +1226,22 @@ public string Device
11861226
}
11871227
}
11881228

1229+
[DllImport("LibTorchSharp")]
1230+
extern static IntPtr THS_cpu(IntPtr handle);
1231+
1232+
public ITorchTensor<long> Cpu()
1233+
{
1234+
return new LongTensor(THS_cpu(handle));
1235+
}
1236+
1237+
[DllImport("LibTorchSharp")]
1238+
extern static IntPtr THS_cuda(IntPtr handle);
1239+
1240+
public ITorchTensor<long> Cuda()
1241+
{
1242+
return new LongTensor(THS_cuda(handle));
1243+
}
1244+
11891245
/// <summary>
11901246
/// Retrieves the size of the specified dimension in the tensor.
11911247
/// </summary>
@@ -1349,7 +1405,6 @@ public override string ToString()
13491405
return sb.ToString();
13501406
}
13511407
}
1352-
13531408
/// <summary>
13541409
/// Tensor of type Double.
13551410
/// This tensor maps to a Torch variable (see torch/csrc/autograd/variable.h).
@@ -1521,6 +1576,22 @@ public string Device
15211576
}
15221577
}
15231578

1579+
[DllImport("LibTorchSharp")]
1580+
extern static IntPtr THS_cpu(IntPtr handle);
1581+
1582+
public ITorchTensor<double> Cpu()
1583+
{
1584+
return new DoubleTensor(THS_cpu(handle));
1585+
}
1586+
1587+
[DllImport("LibTorchSharp")]
1588+
extern static IntPtr THS_cuda(IntPtr handle);
1589+
1590+
public ITorchTensor<double> Cuda()
1591+
{
1592+
return new DoubleTensor(THS_cuda(handle));
1593+
}
1594+
15241595
/// <summary>
15251596
/// Retrieves the size of the specified dimension in the tensor.
15261597
/// </summary>
@@ -1684,7 +1755,6 @@ public override string ToString()
16841755
return sb.ToString();
16851756
}
16861757
}
1687-
16881758
/// <summary>
16891759
/// Tensor of type Float.
16901760
/// This tensor maps to a Torch variable (see torch/csrc/autograd/variable.h).
@@ -1856,6 +1926,22 @@ public string Device
18561926
}
18571927
}
18581928

1929+
[DllImport("LibTorchSharp")]
1930+
extern static IntPtr THS_cpu(IntPtr handle);
1931+
1932+
public ITorchTensor<float> Cpu()
1933+
{
1934+
return new FloatTensor(THS_cpu(handle));
1935+
}
1936+
1937+
[DllImport("LibTorchSharp")]
1938+
extern static IntPtr THS_cuda(IntPtr handle);
1939+
1940+
public ITorchTensor<float> Cuda()
1941+
{
1942+
return new FloatTensor(THS_cuda(handle));
1943+
}
1944+
18591945
/// <summary>
18601946
/// Retrieves the size of the specified dimension in the tensor.
18611947
/// </summary>
@@ -2019,7 +2105,6 @@ public override string ToString()
20192105
return sb.ToString();
20202106
}
20212107
}
2022-
20232108

20242109
public enum ATenScalarMapping : sbyte
20252110
{
@@ -2037,37 +2122,30 @@ internal static ITorchTensor<T> ToTorchTensor<T>(this IntPtr rawTensor)
20372122
{
20382123
switch (true)
20392124
{
2040-
20412125
case bool _ when typeof(T) == typeof(byte):
20422126
{
20432127
return new ByteTensor(rawTensor) as ITorchTensor<T>;
20442128
}
2045-
20462129
case bool _ when typeof(T) == typeof(short):
20472130
{
20482131
return new ShortTensor(rawTensor) as ITorchTensor<T>;
20492132
}
2050-
20512133
case bool _ when typeof(T) == typeof(int):
20522134
{
20532135
return new IntTensor(rawTensor) as ITorchTensor<T>;
20542136
}
2055-
20562137
case bool _ when typeof(T) == typeof(long):
20572138
{
20582139
return new LongTensor(rawTensor) as ITorchTensor<T>;
20592140
}
2060-
20612141
case bool _ when typeof(T) == typeof(double):
20622142
{
20632143
return new DoubleTensor(rawTensor) as ITorchTensor<T>;
20642144
}
2065-
20662145
case bool _ when typeof(T) == typeof(float):
20672146
{
20682147
return new FloatTensor(rawTensor) as ITorchTensor<T>;
20692148
}
2070-
20712149
default: throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
20722150
}
20732151
}
@@ -2076,37 +2154,30 @@ public static ITorchTensor<T> ToTorchTensor<T>(this T[] rawArray, long[] dimensi
20762154
{
20772155
switch (true)
20782156
{
2079-
20802157
case bool _ when typeof(T) == typeof(byte):
20812158
{
20822159
return ByteTensor.From(rawArray as byte[], dimensions) as ITorchTensor<T>;
20832160
}
2084-
20852161
case bool _ when typeof(T) == typeof(short):
20862162
{
20872163
return ShortTensor.From(rawArray as short[], dimensions) as ITorchTensor<T>;
20882164
}
2089-
20902165
case bool _ when typeof(T) == typeof(int):
20912166
{
20922167
return IntTensor.From(rawArray as int[], dimensions) as ITorchTensor<T>;
20932168
}
2094-
20952169
case bool _ when typeof(T) == typeof(long):
20962170
{
20972171
return LongTensor.From(rawArray as long[], dimensions) as ITorchTensor<T>;
20982172
}
2099-
21002173
case bool _ when typeof(T) == typeof(double):
21012174
{
21022175
return DoubleTensor.From(rawArray as double[], dimensions) as ITorchTensor<T>;
21032176
}
2104-
21052177
case bool _ when typeof(T) == typeof(float):
21062178
{
21072179
return FloatTensor.From(rawArray as float[], dimensions) as ITorchTensor<T>;
21082180
}
2109-
21102181
default: throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
21112182
}
21122183
}
@@ -2115,37 +2186,30 @@ public static ITorchTensor<T> ToTorchTensor<T>(this T scalar)
21152186
{
21162187
switch (true)
21172188
{
2118-
21192189
case bool _ when typeof(T) == typeof(byte):
21202190
{
21212191
return ByteTensor.From((byte)(object)scalar) as ITorchTensor<T>;
21222192
}
2123-
21242193
case bool _ when typeof(T) == typeof(short):
21252194
{
21262195
return ShortTensor.From((short)(object)scalar) as ITorchTensor<T>;
21272196
}
2128-
21292197
case bool _ when typeof(T) == typeof(int):
21302198
{
21312199
return IntTensor.From((int)(object)scalar) as ITorchTensor<T>;
21322200
}
2133-
21342201
case bool _ when typeof(T) == typeof(long):
21352202
{
21362203
return LongTensor.From((long)(object)scalar) as ITorchTensor<T>;
21372204
}
2138-
21392205
case bool _ when typeof(T) == typeof(double):
21402206
{
21412207
return DoubleTensor.From((double)(object)scalar) as ITorchTensor<T>;
21422208
}
2143-
21442209
case bool _ when typeof(T) == typeof(float):
21452210
{
21462211
return FloatTensor.From((float)(object)scalar) as ITorchTensor<T>;
21472212
}
2148-
21492213
default: throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
21502214
}
21512215
}

TorchSharp/Generated/TorchTensor.tt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,22 @@ foreach (var type in TorchTypeDef.Types) {
183183
}
184184
}
185185

186+
[DllImport("LibTorchSharp")]
187+
extern static IntPtr THS_cpu(IntPtr handle);
188+
189+
public ITorchTensor<<#=type.Storage#>> Cpu()
190+
{
191+
return new <#=type.Name#>Tensor(THS_cpu(handle));
192+
}
193+
194+
[DllImport("LibTorchSharp")]
195+
extern static IntPtr THS_cuda(IntPtr handle);
196+
197+
public ITorchTensor<<#=type.Storage#>> Cuda()
198+
{
199+
return new <#=type.Name#>Tensor(THS_cuda(handle));
200+
}
201+
186202
/// <summary>
187203
/// Retrieves the size of the specified dimension in the tensor.
188204
/// </summary>

TorchSharp/Tensor/ITorchTensor.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ public interface ITorchTensor<T> : IDisposable
1414

1515
string Device { get; }
1616

17+
ITorchTensor<T> Cpu();
18+
19+
ITorchTensor<T> Cuda();
20+
1721
Span<T> Data { get; }
1822

1923
T Item { get; }

0 commit comments

Comments
 (0)