Skip to content

Commit a4bd5ba

Browse files
authored
Merge pull request #10 from interesaaat/LibTorchSharpFirstTest
2 parents 7a6be72 + 4828aa2 commit a4bd5ba

File tree

5 files changed

+301
-30
lines changed

5 files changed

+301
-30
lines changed

TorchSharp/NN/Init.cs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
using System;
2+
using System.Runtime.InteropServices;
3+
using TorchSharp.Tensor;
4+
5+
namespace TorchSharp.NN
6+
{
7+
public static class Init
8+
{
9+
[DllImport("libTorchSharp")]
10+
extern static void THSNN_initUniform(IntPtr src, double low, double high);
11+
12+
public static void Uniform<T>(ITorchTensor<T> tensor, double low = 0, double high = 1)
13+
{
14+
THSNN_initUniform(tensor.Handle, low, high);
15+
}
16+
17+
[DllImport("libTorchSharp")]
18+
extern static void THSNN_initKaimingUniform(IntPtr src, double a);
19+
20+
public static void KaimingUniform<T>(ITorchTensor<T> tensor, double a = 0)
21+
{
22+
THSNN_initKaimingUniform(tensor.Handle, a);
23+
}
24+
25+
public static (long fanIn, long fanOut) CalculateFanInAndFanOut<T>(ITorchTensor<T> tensor)
26+
{
27+
var dimensions = tensor.Dimensions;
28+
29+
if (dimensions < 2)
30+
{
31+
throw new ArgumentException("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions");
32+
}
33+
34+
var shape = tensor.Shape;
35+
// Linear
36+
if (dimensions == 2)
37+
{
38+
return (shape[1], shape[2]);
39+
}
40+
else
41+
{
42+
var numInputFMaps = tensor.Shape[1];
43+
var numOutputFMaps = tensor.Shape[0];
44+
var receptiveFieldSize = tensor[0, 0].NumberOfElements;
45+
46+
return (numInputFMaps * receptiveFieldSize, numOutputFMaps * receptiveFieldSize);
47+
}
48+
}
49+
}
50+
}

TorchSharp/NN/LossFunction.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ public static ITorchTensor<float> NLL<T, U>(ITorchTensor<T> src, ITorchTensor<U>
3434
}
3535

3636
[DllImport("libTorchSharp")]
37-
extern static IntPtr THSNN_lossPoissonNLL(IntPtr srct, IntPtr trgt, bool logInput, bool full, double eps, long reduction);
37+
extern static IntPtr THSNN_lossPoissonNLL(IntPtr srct, IntPtr trgt, bool logInput, bool full, float eps, long reduction);
3838

39-
public static ITorchTensor<float> PoissonNLL<T, U>(ITorchTensor<T> src, ITorchTensor<U> target, bool logInput = true, bool full = false, double eps = 1e-8, Reduction reduction = Reduction.Mean)
39+
public static ITorchTensor<float> PoissonNLL<T, U>(ITorchTensor<T> src, ITorchTensor<U> target, bool logInput = true, bool full = false, float eps = 1e-8f, Reduction reduction = Reduction.Mean)
4040
{
4141
return new FloatTensor(THSNN_lossPoissonNLL(src.Handle, target.Handle, logInput, full, eps, (long)reduction));
4242
}

TorchSharp/Tensor/ITorchTensor.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@ public interface ITorchTensor<T> : IDisposable
1616

1717
Span<T> Data { get; }
1818

19-
T Item { get; }
19+
T DataItem { get; }
20+
21+
ITorchTensor<T> this[long i1] { get; }
22+
23+
ITorchTensor<T> this[long i1, long i2] { get; }
24+
25+
ITorchTensor<T> this[long i1, long i2, long i3] { get; }
2026

2127
bool IsSparse { get; }
2228

0 commit comments

Comments
 (0)