Skip to content

Commit 79fde58

Browse files
committed
Added device test
1 parent d741422 commit 79fde58

File tree

3 files changed

+47
-44
lines changed

3 files changed

+47
-44
lines changed

Test/TorchSharp.cs

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,33 +16,36 @@ public void CreateFloatTensorOnes()
1616
[TestMethod]
1717
public void CreateFloatTensorOnesCheckData()
1818
{
19-
unsafe
20-
{
21-
var ones = FloatTensor.Ones(new long[] { 2, 2 });
22-
var data = ones.Data;
19+
var ones = FloatTensor.Ones(new long[] { 2, 2 });
20+
var data = ones.Data;
2321

24-
for (int i = 0; i < 4; i++)
25-
{
26-
Assert.AreEqual(data[i], 1.0);
27-
}
22+
for (int i = 0; i < 4; i++)
23+
{
24+
Assert.AreEqual(data[i], 1.0);
2825
}
2926
}
3027

3128
[TestMethod]
3229
public void CreateIntTensorOnesCheckData()
3330
{
34-
unsafe
35-
{
36-
var ones = IntTensor.Ones(new long[] { 2, 2 });
37-
var data = ones.Data;
31+
var ones = IntTensor.Ones(new long[] { 2, 2 });
32+
var data = ones.Data;
3833

39-
for (int i = 0; i < 4; i++)
40-
{
41-
Assert.AreEqual(data[i], 1);
42-
}
34+
for (int i = 0; i < 4; i++)
35+
{
36+
Assert.AreEqual(data[i], 1);
4337
}
4438
}
4539

40+
[TestMethod]
41+
public void CreateFloatTensorCheckDevice()
42+
{
43+
var ones = FloatTensor.Ones(new long[] { 2, 2 });
44+
var device = ones.Device;
45+
46+
Assert.AreEqual(ones.Device, "cpu");
47+
}
48+
4649
[TestMethod]
4750
public void ScoreModel()
4851
{

TorchSharp/Generated/TorchTensor.generated.cs

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,13 @@ public byte Item
154154
}
155155

156156
[DllImport("LibTorchSharp")]
157-
extern static string THS_device(HType handle);
157+
extern static string THS_deviceType(HType handle);
158158

159159
public string Device
160160
{
161161
get
162162
{
163-
return THS_device(handle);
163+
return THS_deviceType(handle);
164164
}
165165
}
166166

@@ -208,7 +208,7 @@ public long GetTensorStride (int dim)
208208
/// <summary>
209209
/// Create a new tensor filled with ones
210210
/// </summary>
211-
static public ITorchTensor<byte> Ones(long[] size, string device = "cpu:0", bool requiresGrad = false)
211+
static public ITorchTensor<byte> Ones(long[] size, string device = "cpu", bool requiresGrad = false)
212212
{
213213
unsafe
214214
{
@@ -225,7 +225,7 @@ static public ITorchTensor<byte> Ones(long[] size, string device = "cpu:0", bool
225225
/// <summary>
226226
/// Create a new tensor filled with ones
227227
/// </summary>
228-
static public ITorchTensor<byte> RandomN(long[] size, string device = "cpu:0", bool requiresGrad = false)
228+
static public ITorchTensor<byte> RandomN(long[] size, string device = "cpu", bool requiresGrad = false)
229229
{
230230
unsafe
231231
{
@@ -438,13 +438,13 @@ public short Item
438438
}
439439

440440
[DllImport("LibTorchSharp")]
441-
extern static string THS_device(HType handle);
441+
extern static string THS_deviceType(HType handle);
442442

443443
public string Device
444444
{
445445
get
446446
{
447-
return THS_device(handle);
447+
return THS_deviceType(handle);
448448
}
449449
}
450450

@@ -492,7 +492,7 @@ public long GetTensorStride (int dim)
492492
/// <summary>
493493
/// Create a new tensor filled with ones
494494
/// </summary>
495-
static public ITorchTensor<short> Ones(long[] size, string device = "cpu:0", bool requiresGrad = false)
495+
static public ITorchTensor<short> Ones(long[] size, string device = "cpu", bool requiresGrad = false)
496496
{
497497
unsafe
498498
{
@@ -509,7 +509,7 @@ static public ITorchTensor<short> Ones(long[] size, string device = "cpu:0", boo
509509
/// <summary>
510510
/// Create a new tensor filled with ones
511511
/// </summary>
512-
static public ITorchTensor<short> RandomN(long[] size, string device = "cpu:0", bool requiresGrad = false)
512+
static public ITorchTensor<short> RandomN(long[] size, string device = "cpu", bool requiresGrad = false)
513513
{
514514
unsafe
515515
{
@@ -722,13 +722,13 @@ public int Item
722722
}
723723

724724
[DllImport("LibTorchSharp")]
725-
extern static string THS_device(HType handle);
725+
extern static string THS_deviceType(HType handle);
726726

727727
public string Device
728728
{
729729
get
730730
{
731-
return THS_device(handle);
731+
return THS_deviceType(handle);
732732
}
733733
}
734734

@@ -776,7 +776,7 @@ public long GetTensorStride (int dim)
776776
/// <summary>
777777
/// Create a new tensor filled with ones
778778
/// </summary>
779-
static public ITorchTensor<int> Ones(long[] size, string device = "cpu:0", bool requiresGrad = false)
779+
static public ITorchTensor<int> Ones(long[] size, string device = "cpu", bool requiresGrad = false)
780780
{
781781
unsafe
782782
{
@@ -793,7 +793,7 @@ static public ITorchTensor<int> Ones(long[] size, string device = "cpu:0", bool
793793
/// <summary>
794794
/// Create a new tensor filled with ones
795795
/// </summary>
796-
static public ITorchTensor<int> RandomN(long[] size, string device = "cpu:0", bool requiresGrad = false)
796+
static public ITorchTensor<int> RandomN(long[] size, string device = "cpu", bool requiresGrad = false)
797797
{
798798
unsafe
799799
{
@@ -1006,13 +1006,13 @@ public long Item
10061006
}
10071007

10081008
[DllImport("LibTorchSharp")]
1009-
extern static string THS_device(HType handle);
1009+
extern static string THS_deviceType(HType handle);
10101010

10111011
public string Device
10121012
{
10131013
get
10141014
{
1015-
return THS_device(handle);
1015+
return THS_deviceType(handle);
10161016
}
10171017
}
10181018

@@ -1060,7 +1060,7 @@ public long GetTensorStride (int dim)
10601060
/// <summary>
10611061
/// Create a new tensor filled with ones
10621062
/// </summary>
1063-
static public ITorchTensor<long> Ones(long[] size, string device = "cpu:0", bool requiresGrad = false)
1063+
static public ITorchTensor<long> Ones(long[] size, string device = "cpu", bool requiresGrad = false)
10641064
{
10651065
unsafe
10661066
{
@@ -1077,7 +1077,7 @@ static public ITorchTensor<long> Ones(long[] size, string device = "cpu:0", bool
10771077
/// <summary>
10781078
/// Create a new tensor filled with ones
10791079
/// </summary>
1080-
static public ITorchTensor<long> RandomN(long[] size, string device = "cpu:0", bool requiresGrad = false)
1080+
static public ITorchTensor<long> RandomN(long[] size, string device = "cpu", bool requiresGrad = false)
10811081
{
10821082
unsafe
10831083
{
@@ -1290,13 +1290,13 @@ public double Item
12901290
}
12911291

12921292
[DllImport("LibTorchSharp")]
1293-
extern static string THS_device(HType handle);
1293+
extern static string THS_deviceType(HType handle);
12941294

12951295
public string Device
12961296
{
12971297
get
12981298
{
1299-
return THS_device(handle);
1299+
return THS_deviceType(handle);
13001300
}
13011301
}
13021302

@@ -1344,7 +1344,7 @@ public long GetTensorStride (int dim)
13441344
/// <summary>
13451345
/// Create a new tensor filled with ones
13461346
/// </summary>
1347-
static public ITorchTensor<double> Ones(long[] size, string device = "cpu:0", bool requiresGrad = false)
1347+
static public ITorchTensor<double> Ones(long[] size, string device = "cpu", bool requiresGrad = false)
13481348
{
13491349
unsafe
13501350
{
@@ -1361,7 +1361,7 @@ static public ITorchTensor<double> Ones(long[] size, string device = "cpu:0", bo
13611361
/// <summary>
13621362
/// Create a new tensor filled with ones
13631363
/// </summary>
1364-
static public ITorchTensor<double> RandomN(long[] size, string device = "cpu:0", bool requiresGrad = false)
1364+
static public ITorchTensor<double> RandomN(long[] size, string device = "cpu", bool requiresGrad = false)
13651365
{
13661366
unsafe
13671367
{
@@ -1574,13 +1574,13 @@ public float Item
15741574
}
15751575

15761576
[DllImport("LibTorchSharp")]
1577-
extern static string THS_device(HType handle);
1577+
extern static string THS_deviceType(HType handle);
15781578

15791579
public string Device
15801580
{
15811581
get
15821582
{
1583-
return THS_device(handle);
1583+
return THS_deviceType(handle);
15841584
}
15851585
}
15861586

@@ -1628,7 +1628,7 @@ public long GetTensorStride (int dim)
16281628
/// <summary>
16291629
/// Create a new tensor filled with ones
16301630
/// </summary>
1631-
static public ITorchTensor<float> Ones(long[] size, string device = "cpu:0", bool requiresGrad = false)
1631+
static public ITorchTensor<float> Ones(long[] size, string device = "cpu", bool requiresGrad = false)
16321632
{
16331633
unsafe
16341634
{
@@ -1645,7 +1645,7 @@ static public ITorchTensor<float> Ones(long[] size, string device = "cpu:0", boo
16451645
/// <summary>
16461646
/// Create a new tensor filled with ones
16471647
/// </summary>
1648-
static public ITorchTensor<float> RandomN(long[] size, string device = "cpu:0", bool requiresGrad = false)
1648+
static public ITorchTensor<float> RandomN(long[] size, string device = "cpu", bool requiresGrad = false)
16491649
{
16501650
unsafe
16511651
{

TorchSharp/Generated/TorchTensor.tt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,13 @@ foreach (var type in TorchTypeDef.Types) {
161161
}
162162

163163
[DllImport("LibTorchSharp")]
164-
extern static string THS_device(HType handle);
164+
extern static string THS_deviceType(HType handle);
165165

166166
public string Device
167167
{
168168
get
169169
{
170-
return THS_device(handle);
170+
return THS_deviceType(handle);
171171
}
172172
}
173173

@@ -215,7 +215,7 @@ foreach (var type in TorchTypeDef.Types) {
215215
/// <summary>
216216
/// Create a new tensor filled with ones
217217
/// </summary>
218-
static public ITorchTensor<<#=type.Storage#>> Ones(long[] size, string device = "cpu:0", bool requiresGrad = false)
218+
static public ITorchTensor<<#=type.Storage#>> Ones(long[] size, string device = "cpu", bool requiresGrad = false)
219219
{
220220
unsafe
221221
{
@@ -232,7 +232,7 @@ foreach (var type in TorchTypeDef.Types) {
232232
/// <summary>
233233
/// Create a new tensor filled with ones
234234
/// </summary>
235-
static public ITorchTensor<<#=type.Storage#>> RandomN(long[] size, string device = "cpu:0", bool requiresGrad = false)
235+
static public ITorchTensor<<#=type.Storage#>> RandomN(long[] size, string device = "cpu", bool requiresGrad = false)
236236
{
237237
unsafe
238238
{

0 commit comments

Comments
 (0)