Skip to content

Commit c7a23e5

Browse files
authored
Merge pull request #11 from interesaaat/LibTorchSharpFirstTest
* Renamed Item into DataItem (for some reason it was compiling otherwise)
2 parents a4bd5ba + 94ed5a1 commit c7a23e5

File tree

4 files changed

+85
-13
lines changed

4 files changed

+85
-13
lines changed

Examples/MNIST.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ private static void Train(
101101

102102
if (batchId % _logInterval == 0)
103103
{
104-
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.Item}");
104+
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {loss.DataItem}");
105105
}
106106

107107
batchId++;
@@ -127,11 +127,11 @@ private static void Test(
127127
using (var output = model.Forward(data))
128128
using (var loss = NN.LossFunction.NLL(output, target, reduction: NN.Reduction.Sum))
129129
{
130-
testLoss += loss.Item;
130+
testLoss += loss.DataItem;
131131

132132
var pred = output.Argmax(1);
133133

134-
correct += pred.Eq(target).Sum().Item; // Memory leak here
134+
correct += pred.Eq(target).Sum().DataItem; // Memory leak here
135135

136136
data.Dispose();
137137
target.Dispose();

Test/TorchSharp/TorchSharp.cs

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public void CreateFloatTensorFromScalar()
124124

125125
using (var tensor = FloatTensor.From(scalar))
126126
{
127-
Assert.AreEqual(tensor.Item, 333);
127+
Assert.AreEqual(tensor.DataItem, 333);
128128
}
129129
}
130130

@@ -135,7 +135,7 @@ public void CreateFloatTensorFromScalar2()
135135

136136
using (var tensor = scalar.ToTorchTensor())
137137
{
138-
Assert.AreEqual(tensor.Item, 333);
138+
Assert.AreEqual(tensor.DataItem, 333);
139139
}
140140
}
141141

@@ -144,7 +144,7 @@ public void InitUniform()
144144
{
145145
using (var tensor = FloatTensor.Zeros(new long[] { 2, 2 }))
146146
{
147-
tensor.InitUniform();
147+
NN.Init.Uniform(tensor);
148148

149149
Assert.IsNotNull(tensor);
150150
}
@@ -281,6 +281,27 @@ public void CreateLinear()
281281
var modules = lin.GetName();
282282
}
283283

284+
[TestMethod]
285+
public void TestGetBiasInLinear()
286+
{
287+
var lin = NN.Module.Linear(1000, 100);
288+
Assert.IsFalse(lin.WithBias);
289+
Assert.ThrowsException<ArgumentNullException>(() => lin.Bias);
290+
}
291+
292+
[TestMethod]
293+
public void TestSetGetBiasInLinear()
294+
{
295+
var lin = NN.Module.Linear(1000, 100, true);
296+
Assert.IsNotNull(lin.Bias);
297+
298+
var bias = FloatTensor.Ones(new long[] { 1000 });
299+
300+
lin.Bias = bias;
301+
302+
Assert.AreEqual(lin.Bias.NumberOfElements, bias.NumberOfElements);
303+
}
304+
284305
[TestMethod]
285306
public void CreateRelu()
286307
{
@@ -323,7 +344,7 @@ public void EvalLossSequence()
323344
var eval = seq.Forward(x);
324345
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
325346

326-
var result = loss.Item;
347+
var result = loss.DataItem;
327348
Assert.IsNotNull(result);
328349
}
329350

@@ -501,7 +522,7 @@ public void TestTraining()
501522
{
502523
var eval = seq.Forward(x);
503524
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
504-
var lossVal = loss.Item;
525+
var lossVal = loss.DataItem;
505526

506527
Assert.IsTrue(lossVal < prevLoss);
507528
prevLoss = lossVal;
@@ -558,7 +579,7 @@ public void TestTrainingAdam()
558579
{
559580
var eval = seq.Forward(x);
560581
var loss = NN.LossFunction.MSE(eval, y, NN.Reduction.Sum);
561-
var lossVal = loss.Item;
582+
var lossVal = loss.DataItem;
562583

563584
Assert.IsTrue(lossVal < prevLoss);
564585
prevLoss = lossVal;

TorchSharp/NN/Linear.cs

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,61 @@ namespace TorchSharp.NN
66
{
77
public class Linear : ProvidedModule
88
{
9-
internal Linear(IntPtr handle) : base(handle)
9+
public Linear(IntPtr handle) : base(handle)
1010
{
1111
}
1212

13+
[DllImport("libTorchSharp")]
14+
extern static IntPtr THSNN_linearModule(int input, int output, bool hasBias);
15+
16+
public Linear(int input, int output, bool hasBias = false) : base()
17+
{
18+
handle = new HType(THSNN_linearModule(input, output, hasBias), true);
19+
}
20+
21+
[DllImport("libTorchSharp")]
22+
extern static bool THSNN_linear_with_bias(Module.HType module);
23+
24+
public bool WithBias
25+
{
26+
get { return THSNN_linear_with_bias(handle); }
27+
}
28+
29+
[DllImport("libTorchSharp")]
30+
extern static IntPtr THSNN_linear_get_bias(Module.HType module);
31+
32+
[DllImport("libTorchSharp")]
33+
extern static void THSNN_linear_set_bias(Module.HType module, IntPtr tensor);
34+
35+
public ITorchTensor<float> Bias
36+
{
37+
get
38+
{
39+
var bias = THSNN_linear_get_bias(handle);
40+
if (bias == IntPtr.Zero)
41+
{
42+
throw new ArgumentNullException("Linear module without bias term.");
43+
}
44+
return new FloatTensor(bias);
45+
}
46+
set { THSNN_linear_set_bias(handle, value.Handle); }
47+
}
48+
49+
[DllImport("libTorchSharp")]
50+
extern static IntPtr THSNN_linear_get_weight(Module.HType module);
51+
52+
[DllImport("libTorchSharp")]
53+
extern static void THSNN_linear_set_weight(Module.HType module, IntPtr tensor);
54+
55+
public ITorchTensor<float> Weight
56+
{
57+
get
58+
{
59+
return new FloatTensor(THSNN_linear_get_weight(handle));
60+
}
61+
set { THSNN_linear_set_weight(handle, value.Handle); }
62+
}
63+
1364
[DllImport("libTorchSharp")]
1465
extern static IntPtr THSNN_linearModuleApply(Module.HType module, IntPtr tensor);
1566

TorchSharp/NN/Module.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public partial class Module
1111
/// <summary>
1212
/// Class wrapping PyTorch's module object reference.
1313
/// </summary>
14-
internal sealed class HType : SafeHandle
14+
protected sealed class HType : SafeHandle
1515
{
1616
public HType(IntPtr preexistingHandle, bool ownsHandle) : base(IntPtr.Zero, ownsHandle)
1717
{
@@ -43,7 +43,7 @@ protected override void Dispose(bool disposing)
4343
}
4444
}
4545

46-
internal HType handle;
46+
protected HType handle;
4747

4848
protected bool _isTraining = true;
4949

@@ -96,7 +96,7 @@ static public Sequential Sequential(params Module[] modules)
9696
[DllImport("libTorchSharp")]
9797
extern static IntPtr THSNN_linearModule(int input, int output, bool hasBias);
9898

99-
static public Module Linear(int input, int output, bool hasBias = false)
99+
static public Linear Linear(int input, int output, bool hasBias = false)
100100
{
101101
return new Linear(THSNN_linearModule(input, output, hasBias));
102102
}

0 commit comments

Comments
 (0)