Skip to content

Commit ecc54e1

Browse files
committed
Add some scalar stuff plus ability to set values using indexing over tensors.
1 parent d677818 commit ecc54e1

File tree

3 files changed

+144
-10
lines changed

3 files changed

+144
-10
lines changed

Test/TorchSharp/TorchSharp.cs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
3-
using System.Diagnostics;
43
using System.Linq;
54
using System.Runtime.InteropServices;
65
using TorchSharp.JIT;
7-
using TorchSharp.NN;
86
using TorchSharp.Tensor;
97

108
namespace TorchSharp.Test
@@ -128,7 +126,7 @@ public void CreateFloatTensorFromScalar()
128126

129127
using (var tensor = FloatTensor.From(scalar))
130128
{
131-
Assert.AreEqual(tensor.DataItem<float>(), 333);
129+
Assert.AreEqual(tensor.Item<float>(), 333);
132130
}
133131
}
134132

@@ -139,7 +137,19 @@ public void CreateFloatTensorFromScalar2()
139137

140138
using (var tensor = scalar.ToTorchTensor())
141139
{
142-
Assert.AreEqual(tensor.DataItem<float>(), 333);
140+
Assert.AreEqual(tensor.Item<float>(), 333);
141+
}
142+
}
143+
144+
[TestMethod]
145+
public void TextIndexSet()
146+
{
147+
var tensor = IntTensor.Zeros(new long[] { 2 });
148+
149+
using (var value = 1.ToTorchTensor())
150+
{
151+
tensor[0] = value;
152+
Assert.AreEqual(tensor.Data<int>()[0], 1);
143153
}
144154
}
145155

@@ -436,7 +446,7 @@ public void EvalLossSequence()
436446
var loss = NN.LossFunction.MSE(NN.Reduction.Sum);
437447
var output = loss(eval, y);
438448

439-
var result = output.DataItem<float>();
449+
var result = output.Item<float>();
440450
Assert.IsNotNull(result);
441451
}
442452

@@ -681,7 +691,7 @@ public void TestTraining()
681691
{
682692
var eval = seq.Forward(x);
683693
var output = loss(eval, y);
684-
var lossVal = output.DataItem<float>();
694+
var lossVal = output.Item<float>();
685695

686696
Assert.IsTrue(lossVal < prevLoss);
687697
prevLoss = lossVal;
@@ -739,7 +749,7 @@ public void TestTrainingAdam()
739749
{
740750
var eval = seq.Forward(x);
741751
var output = loss(eval, y);
742-
var lossVal = output.DataItem<float>();
752+
var lossVal = output.Item<float>();
743753

744754
Assert.IsTrue(lossVal < prevLoss);
745755
prevLoss = lossVal;

TorchSharp/Scalar.cs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
using System;
2+
using System.Runtime.InteropServices;
3+
4+
namespace TorchSharp
5+
{
6+
public sealed class Scalar : IDisposable
7+
{
8+
internal IntPtr Handle { get; private set; }
9+
10+
internal Scalar(IntPtr handle)
11+
{
12+
Handle = handle;
13+
}
14+
15+
/// <summary>
16+
/// Releases the storage.
17+
/// </summary>
18+
public void Dispose()
19+
{
20+
Dispose(true);
21+
GC.SuppressFinalize(this);
22+
}
23+
24+
[DllImport("libTorchSharp")]
25+
extern static void THSThorch_dispose_scalar(IntPtr handle);
26+
27+
/// <summary>
28+
/// Implements the .NET Dispose pattern.
29+
/// </summary>
30+
internal void Dispose(bool disposing)
31+
{
32+
if (disposing)
33+
{
34+
THSThorch_dispose_scalar(Handle);
35+
Handle = IntPtr.Zero;
36+
}
37+
}
38+
}
39+
40+
public static class ScalarExtensionMethods
41+
{
42+
[DllImport("libTorchSharp")]
43+
extern static IntPtr THSTorch_btos(byte hanvaluedle);
44+
45+
public static Scalar ToScalar(this byte value)
46+
{
47+
return new Scalar(THSTorch_btos(value));
48+
}
49+
50+
[DllImport("libTorchSharp")]
51+
extern static IntPtr THSTorch_stos(short hanvaluedle);
52+
53+
public static Scalar ToScalar(this short value)
54+
{
55+
return new Scalar(THSTorch_stos(value));
56+
}
57+
58+
[DllImport("libTorchSharp")]
59+
extern static IntPtr THSTorch_itos(int hanvaluedle);
60+
61+
public static Scalar ToScalar(this int value)
62+
{
63+
return new Scalar(THSTorch_itos(value));
64+
}
65+
66+
[DllImport("libTorchSharp")]
67+
extern static IntPtr THSTorch_ltos(long hanvaluedle);
68+
69+
public static Scalar ToScalar(this long value)
70+
{
71+
return new Scalar(THSTorch_ltos(value));
72+
}
73+
74+
[DllImport("libTorchSharp")]
75+
extern static IntPtr THSTorch_ftos(float hanvaluedle);
76+
77+
public static Scalar ToScalar(this float value)
78+
{
79+
return new Scalar(THSTorch_ftos(value));
80+
}
81+
82+
[DllImport("libTorchSharp")]
83+
extern static IntPtr THSTorch_dtos(double hanvaluedle);
84+
85+
public static Scalar ToScalar(this double value)
86+
{
87+
return new Scalar(THSTorch_dtos(value));
88+
}
89+
}
90+
}

TorchSharp/Tensor/TorchTensor.cs

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ public Span<T> Data<T>()
9898
}
9999
}
100100

101-
public T DataItem<T>()
101+
public T Item<T>()
102102
{
103103
if (NumberOfElements != 1)
104104
{
@@ -107,37 +107,71 @@ public T DataItem<T>()
107107
return Data<T>()[0];
108108
}
109109

110+
[DllImport("libTorchSharp")]
111+
extern static IntPtr THSTensor_item(IntPtr handle);
112+
113+
public Scalar Item()
114+
{
115+
var sptr = THSTensor_item(Handle);
116+
Torch.AssertNoErrors();
117+
return new Scalar(sptr);
118+
}
119+
110120
[DllImport("libTorchSharp")]
111121
extern static IntPtr THSTensor_get1(IntPtr handle, long i1);
112122

123+
[DllImport("libTorchSharp")]
124+
extern static IntPtr THSTensor_set1(IntPtr handle, long i1, IntPtr value);
125+
126+
[System.Runtime.CompilerServices.IndexerName("TensorItems")]
113127
public TorchTensor this[long i1]
114128
{
115-
get
129+
get { return new TorchTensor(THSTensor_get1(handle, i1)); }
130+
set
116131
{
117-
return new TorchTensor(THSTensor_get1(handle, i1));
132+
THSTensor_set1(handle, i1, value.Item().Handle);
133+
Torch.AssertNoErrors();
118134
}
119135
}
120136

121137
[DllImport("libTorchSharp")]
122138
extern static IntPtr THSTensor_get2(IntPtr handle, long i1, long i2);
123139

140+
[DllImport("libTorchSharp")]
141+
extern static IntPtr THSTensor_set2(IntPtr handle, long i1, long i2, IntPtr value);
142+
143+
[System.Runtime.CompilerServices.IndexerName("TensorItems")]
124144
public TorchTensor this[long i1, long i2]
125145
{
126146
get
127147
{
128148
return new TorchTensor(THSTensor_get2(handle, i1, i2));
129149
}
150+
set
151+
{
152+
THSTensor_set2(handle, i1, i2, value.Item().Handle);
153+
Torch.AssertNoErrors();
154+
}
130155
}
131156

132157
[DllImport("libTorchSharp")]
133158
extern static IntPtr THSTensor_get3(IntPtr handle, long i1, long i2, long i3);
134159

160+
[DllImport("libTorchSharp")]
161+
extern static IntPtr THSTensor_set3(IntPtr handle, long i1, long i2, long i3, IntPtr value);
162+
163+
[System.Runtime.CompilerServices.IndexerName("TensorItems")]
135164
public TorchTensor this[long i1, long i2, long i3]
136165
{
137166
get
138167
{
139168
return new TorchTensor(THSTensor_get3(handle, i1, i2, i3));
140169
}
170+
set
171+
{
172+
THSTensor_set3(handle, i1, i2, i3, value.Item().Handle);
173+
Torch.AssertNoErrors();
174+
}
141175
}
142176

143177
[DllImport("libTorchSharp")]

0 commit comments

Comments
 (0)