Skip to content

Commit c5446f2

Browse files
committed
Bind Random
1 parent 7a0a5b3 commit c5446f2

File tree

4 files changed

+184
-37
lines changed

4 files changed

+184
-37
lines changed

Tester/Program.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public static void Main (string [] args)
1515
var b = new FloatTensor (10);
1616
b.Fill (30);
1717
Dump (b);
18-
x.Random (new THRandom (), 10);
18+
x.Random (new RandomGenerator (), 10);
1919
FloatTensor.Add (x, 100, b);
2020
Dump (x);
2121
Dump (b);

TorchSharp/THRandom.cs

Lines changed: 155 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ namespace TorchSharp {
55
/// <summary>
66
/// Random class
77
/// </summary>
8-
/// <remarks>
8+
/// <remarks>
99
/// Behind the scenes this is the THGenerator API and THRandom combined into
10-
/// one as THRandom are just convenience methods on top of THGenerator.
11-
/// </remarks>
12-
public class THRandom : IDisposable {
10+
/// one as THRandom are just convenience methods on top of THGenerator.
11+
/// </remarks>
12+
public class RandomGenerator : IDisposable {
1313
internal IntPtr handle;
1414

1515
[DllImport ("caffe2")]
@@ -18,7 +18,7 @@ public class THRandom : IDisposable {
1818
[DllImport ("caffe2")]
1919
extern static IntPtr THGenerator_new ();
2020

21-
public THRandom ()
21+
public RandomGenerator ()
2222
{
2323
handle = THGenerator_new ();
2424
}
@@ -36,7 +36,7 @@ protected virtual void Dispose (bool disposing)
3636
}
3737
}
3838

39-
~THRandom ()
39+
~RandomGenerator()
4040
{
4141
Dispose (false);
4242
}
@@ -45,7 +45,154 @@ protected virtual void Dispose (bool disposing)
4545
public void Dispose ()
4646
{
4747
Dispose (true);
48-
GC.SuppressFinalize(this);
48+
GC.SuppressFinalize (this);
4949
}
50+
51+
[DllImport ("caffe2")]
52+
extern static ulong THRandom_seed (IntPtr handle);
53+
54+
/// <summary>
55+
/// Initializes the random number generator from /dev/urandom or in Windows with the current time.
56+
/// </summary>
57+
/// <returns>The random seed.</returns>
58+
public ulong InitRandomSeed () => THRandom_seed (handle);
59+
60+
[DllImport ("caffe2")]
61+
extern static void THRandom_manualSeed (IntPtr handle, ulong seed);
62+
63+
/// <summary>
64+
/// Initializes the random number generator with the given seed.
65+
/// </summary>
66+
/// <param name="seed">Seed.</param>
67+
public void InitWithSeed (ulong seed) => THRandom_manualSeed (handle, seed);
68+
69+
70+
[DllImport ("caffe2")]
71+
extern static ulong THRandom_initialSeed (IntPtr handle);
72+
73+
/// <summary>
74+
/// Returns the starting seed used.
75+
/// </summary>
76+
/// <value>The initial seed.</value>
77+
public ulong InitialSeed => THRandom_initialSeed (handle);
78+
79+
[DllImport ("caffe2")]
80+
extern static ulong THRandom_random (IntPtr handle);
81+
82+
/// <summary>
83+
/// Generates a uniform 32 bits integer.
84+
/// </summary>
85+
/// <returns>UInt32 random value.</returns>
86+
public uint NextUInt32 () => (uint)THRandom_random (handle);
87+
88+
[DllImport ("caffe2")]
89+
extern static ulong THRandom_random64 (IntPtr handle);
90+
91+
/// <summary>
92+
/// Generates a uniform 64 bits integer.
93+
/// </summary>
94+
/// <returns>UInt64 random value.</returns>
95+
public ulong NextUInt64 () => THRandom_random64 (handle);
96+
97+
[DllImport ("caffe2")]
98+
extern static double THRandom_standard_uniform (IntPtr handle);
99+
100+
/// <summary>
101+
/// Generates a uniform random double on [0,1).
102+
/// </summary>
103+
/// <returns>Generates a uniform random double on [0,1).</returns>
104+
public double NextDouble () => THRandom_standard_uniform (handle);
105+
106+
[DllImport ("caffe2")]
107+
extern static double THRandom_uniform (IntPtr handle, double a, double b);
108+
109+
/// <summary>
110+
/// Generates a uniform random double on [a,b).
111+
/// </summary>
112+
/// <returns>Generates a uniform random double on [a, b).</returns>
113+
public double NextDouble (double a, double b) => THRandom_uniform (handle, a, b);
114+
115+
[DllImport ("caffe2")]
116+
extern static float THRandom_uniformFloat (IntPtr handle, float a, float b);
117+
118+
/// <summary>
119+
/// Generates a uniform random float on [a,b).
120+
/// </summary>
121+
/// <returns>Generates a uniform random float on [a, b).</returns>
122+
public double NextFloat (float a, float b) => THRandom_uniformFloat (handle, a, b);
123+
124+
[DllImport ("caffe2")]
125+
extern static double THRandom_normal (IntPtr handle, double mean, double stddev);
126+
127+
/// <summary>
128+
/// Generates a random number from a normal distribution.
129+
/// </summary>
130+
/// <param name="mean">Mean for the distribution</param>
131+
/// <param name="stdev">Stanard deviation for the distribution, > 0 </param>
132+
/// <returns>Generates a number for the normal distribution.</returns>
133+
public double NextNormalDouble (double mean, double stdev) => THRandom_normal (handle, mean, stdev);
134+
135+
[DllImport ("caffe2")]
136+
extern static double THRandom_exponential (IntPtr handle, double lambda);
137+
138+
/// <summary>
139+
/// Generates a random number from an exponential distribution.
140+
/// </summary>
141+
/// <param name="lambda">Must be a positive number</param>
142+
/// <remarks>
143+
/// The density is $p(x) = lambda * exp(-lambda * x)$, where lambda is a positive number.
144+
/// </remarks>
145+
public double NextExponentialDouble (double lambda) => THRandom_exponential (handle, lambda);
146+
147+
[DllImport ("caffe2")]
148+
extern static double THRandom_cauchy (IntPtr handle, double median, double sigma);
149+
150+
/// <summary>
151+
/// Returns a random number from a Cauchy distribution.
152+
/// </summary>
153+
/// <param name="lambda">Must be a positive number</param>
154+
/// <remarks>
155+
/// The Cauchy density is $p(x) = sigma/(pi*(sigma^2 + (x-median)^2))$
156+
/// </remarks>
157+
public double NextCauchyDouble (double median, double sigma) => THRandom_cauchy (handle, median, sigma);
158+
159+
[DllImport ("caffe2")]
160+
extern static double THRandom_logNormal (IntPtr handle, double mean, double stddev);
161+
162+
/// <summary>
163+
/// Generates a random number from a log-normal distribution.
164+
/// </summary>
165+
/// <param name="mean">&gt; 0 is the mean of the log-normal distribution</param>
166+
/// <param name="stddev">is its standard deviation.</param>
167+
public double NextLogNormalDouble (double mean, double stddev) => THRandom_logNormal (handle, mean, stddev);
168+
169+
[DllImport ("caffe2")]
170+
extern static double THRandom_geometric (IntPtr handle, double p);
171+
172+
/// <summary>
173+
/// Generates a random number from a geometric distribution.
174+
/// </summary>
175+
/// <remarks>
176+
/// It returns an integer i, where p(i) = (1-p) * p^(i-1).
177+
/// p must satisfy $0 &lt; p &lt; 1
178+
/// </remarks>
179+
public double NextGeometricDouble (double mean, double p) => THRandom_geometric (handle, p);
180+
181+
[DllImport ("caffe2")]
182+
extern static double THRandom_bernoulli (IntPtr handle, double p);
183+
184+
/// <summary>
185+
/// Returns true with double probability $p$ and false with probability 1-p (p &gt; 0).
186+
/// </summary>
187+
public double NextBernoulliDouble (double mean, double p) => THRandom_bernoulli (handle, p);
188+
189+
[DllImport ("caffe2")]
190+
extern static float THRandom_bernoulliFloat (IntPtr handle, float p);
191+
192+
/// <summary>
193+
/// Returns true with double probability $p$ and false with probability 1-p (p &gt; 0).
194+
/// </summary>
195+
public float NextBernoulliDouble (double mean, float p) => THRandom_bernoulliFloat (handle, p);
196+
50197
}
51-
}
198+
}

TorchSharp/TypeGeneration.cs

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ internal ByteStorage (HType fromHandle)
2323
}
2424

2525
[DllImport ("caffe2")]
26-
extern static HType THByteStorage_new_wiTHsize (IntPtr size);
26+
extern static HType THByteStorage_new_withSize (IntPtr size);
2727

2828
public ByteStorage (long size)
2929
{
30-
handle = THByteStorage_new_wiTHsize ((IntPtr) size);
30+
handle = THByteStorage_new_withSize ((IntPtr) size);
3131
}
3232

3333
~ByteStorage ()
@@ -395,7 +395,7 @@ public byte this [long x0] {
395395

396396
[DllImport ("caffe2")]
397397
extern static byte THByteTensor_randperm (HType handle, IntPtr thgenerator, long n);
398-
public void Random (THRandom source, long n)
398+
public void Random (RandomGenerator source, long n)
399399
{
400400
if (source == null)
401401
throw new ArgumentNullException (nameof (source));
@@ -404,7 +404,7 @@ public void Random (THRandom source, long n)
404404

405405
public void Random (long n)
406406
{
407-
using (var r = new THRandom ())
407+
using (var r = new RandomGenerator ())
408408
Random (r, n);
409409
}
410410

@@ -963,11 +963,11 @@ internal ShortStorage (HType fromHandle)
963963
}
964964

965965
[DllImport ("caffe2")]
966-
extern static HType THShortStorage_new_wiTHsize (IntPtr size);
966+
extern static HType THShortStorage_new_withSize (IntPtr size);
967967

968968
public ShortStorage (long size)
969969
{
970-
handle = THShortStorage_new_wiTHsize ((IntPtr) size);
970+
handle = THShortStorage_new_withSize ((IntPtr) size);
971971
}
972972

973973
~ShortStorage ()
@@ -1335,7 +1335,7 @@ public short this [long x0] {
13351335

13361336
[DllImport ("caffe2")]
13371337
extern static short THShortTensor_randperm (HType handle, IntPtr thgenerator, long n);
1338-
public void Random (THRandom source, long n)
1338+
public void Random (RandomGenerator source, long n)
13391339
{
13401340
if (source == null)
13411341
throw new ArgumentNullException (nameof (source));
@@ -1344,7 +1344,7 @@ public void Random (THRandom source, long n)
13441344

13451345
public void Random (long n)
13461346
{
1347-
using (var r = new THRandom ())
1347+
using (var r = new RandomGenerator ())
13481348
Random (r, n);
13491349
}
13501350

@@ -1903,11 +1903,11 @@ internal IntStorage (HType fromHandle)
19031903
}
19041904

19051905
[DllImport ("caffe2")]
1906-
extern static HType THIntStorage_new_wiTHsize (IntPtr size);
1906+
extern static HType THIntStorage_new_withSize (IntPtr size);
19071907

19081908
public IntStorage (long size)
19091909
{
1910-
handle = THIntStorage_new_wiTHsize ((IntPtr) size);
1910+
handle = THIntStorage_new_withSize ((IntPtr) size);
19111911
}
19121912

19131913
~IntStorage ()
@@ -2275,7 +2275,7 @@ public int this [long x0] {
22752275

22762276
[DllImport ("caffe2")]
22772277
extern static int THIntTensor_randperm (HType handle, IntPtr thgenerator, long n);
2278-
public void Random (THRandom source, long n)
2278+
public void Random (RandomGenerator source, long n)
22792279
{
22802280
if (source == null)
22812281
throw new ArgumentNullException (nameof (source));
@@ -2284,7 +2284,7 @@ public void Random (THRandom source, long n)
22842284

22852285
public void Random (long n)
22862286
{
2287-
using (var r = new THRandom ())
2287+
using (var r = new RandomGenerator ())
22882288
Random (r, n);
22892289
}
22902290

@@ -2843,11 +2843,11 @@ internal LongStorage (HType fromHandle)
28432843
}
28442844

28452845
[DllImport ("caffe2")]
2846-
extern static HType THLongStorage_new_wiTHsize (IntPtr size);
2846+
extern static HType THLongStorage_new_withSize (IntPtr size);
28472847

28482848
public LongStorage (long size)
28492849
{
2850-
handle = THLongStorage_new_wiTHsize ((IntPtr) size);
2850+
handle = THLongStorage_new_withSize ((IntPtr) size);
28512851
}
28522852

28532853
~LongStorage ()
@@ -3215,7 +3215,7 @@ public long this [long x0] {
32153215

32163216
[DllImport ("caffe2")]
32173217
extern static long THLongTensor_randperm (HType handle, IntPtr thgenerator, long n);
3218-
public void Random (THRandom source, long n)
3218+
public void Random (RandomGenerator source, long n)
32193219
{
32203220
if (source == null)
32213221
throw new ArgumentNullException (nameof (source));
@@ -3224,7 +3224,7 @@ public void Random (THRandom source, long n)
32243224

32253225
public void Random (long n)
32263226
{
3227-
using (var r = new THRandom ())
3227+
using (var r = new RandomGenerator ())
32283228
Random (r, n);
32293229
}
32303230

@@ -3783,11 +3783,11 @@ internal DoubleStorage (HType fromHandle)
37833783
}
37843784

37853785
[DllImport ("caffe2")]
3786-
extern static HType THDoubleStorage_new_wiTHsize (IntPtr size);
3786+
extern static HType THDoubleStorage_new_withSize (IntPtr size);
37873787

37883788
public DoubleStorage (long size)
37893789
{
3790-
handle = THDoubleStorage_new_wiTHsize ((IntPtr) size);
3790+
handle = THDoubleStorage_new_withSize ((IntPtr) size);
37913791
}
37923792

37933793
~DoubleStorage ()
@@ -4155,7 +4155,7 @@ public double this [long x0] {
41554155

41564156
[DllImport ("caffe2")]
41574157
extern static double THDoubleTensor_randperm (HType handle, IntPtr thgenerator, long n);
4158-
public void Random (THRandom source, long n)
4158+
public void Random (RandomGenerator source, long n)
41594159
{
41604160
if (source == null)
41614161
throw new ArgumentNullException (nameof (source));
@@ -4164,7 +4164,7 @@ public void Random (THRandom source, long n)
41644164

41654165
public void Random (long n)
41664166
{
4167-
using (var r = new THRandom ())
4167+
using (var r = new RandomGenerator ())
41684168
Random (r, n);
41694169
}
41704170

@@ -4637,11 +4637,11 @@ internal FloatStorage (HType fromHandle)
46374637
}
46384638

46394639
[DllImport ("caffe2")]
4640-
extern static HType THFloatStorage_new_wiTHsize (IntPtr size);
4640+
extern static HType THFloatStorage_new_withSize (IntPtr size);
46414641

46424642
public FloatStorage (long size)
46434643
{
4644-
handle = THFloatStorage_new_wiTHsize ((IntPtr) size);
4644+
handle = THFloatStorage_new_withSize ((IntPtr) size);
46454645
}
46464646

46474647
~FloatStorage ()
@@ -5009,7 +5009,7 @@ public float this [long x0] {
50095009

50105010
[DllImport ("caffe2")]
50115011
extern static float THFloatTensor_randperm (HType handle, IntPtr thgenerator, long n);
5012-
public void Random (THRandom source, long n)
5012+
public void Random (RandomGenerator source, long n)
50135013
{
50145014
if (source == null)
50155015
throw new ArgumentNullException (nameof (source));
@@ -5018,7 +5018,7 @@ public void Random (THRandom source, long n)
50185018

50195019
public void Random (long n)
50205020
{
5021-
using (var r = new THRandom ())
5021+
using (var r = new RandomGenerator ())
50225022
Random (r, n);
50235023
}
50245024

0 commit comments

Comments
 (0)