Skip to content

Commit fe18684

Browse files
committed
Adding random engine class and relavent methods
- Also added necessary tests
1 parent 7594769 commit fe18684

File tree

5 files changed

+273
-109
lines changed

5 files changed

+273
-109
lines changed

arrayfire/data.py

Lines changed: 1 addition & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .array import *
1717
from .util import *
1818
from .util import _is_number
19+
from .random import randu, randn, set_seed, get_seed
1920

2021
def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
2122
"""
@@ -186,105 +187,6 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=Dtype.f32)
186187
4, ct.pointer(tdims), dtype.value))
187188
return out
188189

189-
def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
190-
"""
191-
Create a multi dimensional array containing values from a uniform distribution.
192-
193-
Parameters
194-
----------
195-
d0 : int.
196-
Length of first dimension.
197-
198-
d1 : optional: int. default: None.
199-
Length of second dimension.
200-
201-
d2 : optional: int. default: None.
202-
Length of third dimension.
203-
204-
d3 : optional: int. default: None.
205-
Length of fourth dimension.
206-
207-
dtype : optional: af.Dtype. default: af.Dtype.f32.
208-
Data type of the array.
209-
210-
Returns
211-
-------
212-
213-
out : af.Array
214-
Multi dimensional array whose elements are sampled uniformly between [0, 1].
215-
- If d1 is None, `out` is 1D of size (d0,).
216-
- If d1 is not None and d2 is None, `out` is 2D of size (d0, d1).
217-
- If d1 and d2 are not None and d3 is None, `out` is 3D of size (d0, d1, d2).
218-
- If d1, d2, d3 are all not None, `out` is 4D of size (d0, d1, d2, d3).
219-
"""
220-
out = Array()
221-
dims = dim4(d0, d1, d2, d3)
222-
223-
safe_call(backend.get().af_randu(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
224-
return out
225-
226-
def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
227-
"""
228-
Create a multi dimensional array containing values from a normal distribution.
229-
230-
Parameters
231-
----------
232-
d0 : int.
233-
Length of first dimension.
234-
235-
d1 : optional: int. default: None.
236-
Length of second dimension.
237-
238-
d2 : optional: int. default: None.
239-
Length of third dimension.
240-
241-
d3 : optional: int. default: None.
242-
Length of fourth dimension.
243-
244-
dtype : optional: af.Dtype. default: af.Dtype.f32.
245-
Data type of the array.
246-
247-
Returns
248-
-------
249-
250-
out : af.Array
251-
Multi dimensional array whose elements are sampled from a normal distribution with mean 0 and sigma of 1.
252-
- If d1 is None, `out` is 1D of size (d0,).
253-
- If d1 is not None and d2 is None, `out` is 2D of size (d0, d1).
254-
- If d1 and d2 are not None and d3 is None, `out` is 3D of size (d0, d1, d2).
255-
- If d1, d2, d3 are all not None, `out` is 4D of size (d0, d1, d2, d3).
256-
"""
257-
258-
out = Array()
259-
dims = dim4(d0, d1, d2, d3)
260-
261-
safe_call(backend.get().af_randn(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
262-
return out
263-
264-
def set_seed(seed=0):
265-
"""
266-
Set the seed for the random number generator.
267-
268-
Parameters
269-
----------
270-
seed: int.
271-
Seed for the random number generator
272-
"""
273-
safe_call(backend.get().af_set_seed(ct.c_ulonglong(seed)))
274-
275-
def get_seed():
276-
"""
277-
Get the seed for the random number generator.
278-
279-
Returns
280-
----------
281-
seed: int.
282-
Seed for the random number generator
283-
"""
284-
seed = ct.c_ulonglong(0)
285-
safe_call(backend.get().af_get_seed(ct.pointer(seed)))
286-
return seed.value
287-
288190
def identity(d0, d1, d2=None, d3=None, dtype=Dtype.f32):
289191
"""
290192
Create an identity matrix or batch of identity matrices.

arrayfire/features.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
# The complete license agreement can be obtained at:
77
# http://arrayfire.com/licenses/BSD-3-Clause
88
########################################################
9+
910
"""
1011
Features class used for Computer Vision algorithms.
1112
"""
13+
1214
from .library import *
1315
from .array import *
1416
import numbers

arrayfire/random.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
#######################################################
2+
# Copyright (c) 2015, ArrayFire
3+
# All rights reserved.
4+
#
5+
# This file is distributed under 3-clause BSD license.
6+
# The complete license agreement can be obtained at:
7+
# http://arrayfire.com/licenses/BSD-3-Clause
8+
########################################################
9+
10+
"""
11+
Random engine class and functions to generate random numbers.
12+
"""
13+
14+
from .library import *
15+
from .array import *
16+
import numbers
17+
18+
class Random_Engine(object):
19+
"""
20+
Class to handle random number generator engines.
21+
22+
Parameters
23+
----------
24+
25+
engine_type : optional: RANDOME_ENGINE. default: RANDOM_ENGINE.PHILOX
26+
- Specifies the type of random engine to be created. Can be one of:
27+
- RANDOM_ENGINE.PHILOX_4X32_10
28+
- RANDOM_ENGINE.THREEFRY_2X32_16
29+
- RANDOM_ENGINE.MERSENNE_GP11213
30+
- RANDOM_ENGINE.PHILOX (same as RANDOM_ENGINE.PHILOX_4X32_10)
31+
- RANDOM_ENGINE.THREEFRY (same as RANDOM_ENGINE.THREEFRY_2X32_16)
32+
- RANDOM_ENGINE.DEFAULT
33+
- Not used if engine is not None
34+
35+
seed : optional int. default: 0
36+
- Specifies the seed for the random engine
37+
- Not used if engine is not None
38+
39+
engine : optional ctypes.c_void_p. default: None.
40+
- Used a handle created by the C api to create the Random_Engine.
41+
"""
42+
43+
def __init__(self, engine_type = RANDOM_ENGINE.PHILOX, seed = 0, engine = None):
44+
if (engine is None):
45+
self.engine = ct.c_void_p(0)
46+
safe_call(backend.get().af_create_random_engine(ct.pointer(self.engine), engine_type.value, ct.c_longlong(seed)))
47+
else:
48+
self.engine = engine
49+
50+
def __del__(self):
51+
safe_call(backend.get().af_release_random_engine(self.engine))
52+
53+
def set_type(self, engine_type):
54+
"""
55+
Set the type of the random engine.
56+
"""
57+
safe_call(backend.get().af_random_engine_set_type(ct.pointer(self.engine), engine_type.value))
58+
59+
def get_type(self):
60+
"""
61+
Get the type of the random engine.
62+
"""
63+
__to_random_engine_type = [RANDOM_ENGINE.PHILOX_4X32_10,
64+
RANDOM_ENGINE.THREEFRY_2X32_16,
65+
RANDOM_ENGINE.MERSENNE_GP11213]
66+
rty = ct.c_int(RANDOM_ENGINE.PHILOX.value)
67+
safe_call(backend.get().af_random_engine_get_type(ct.pointer(rty), self.engine))
68+
return __to_random_engine_type[rty]
69+
70+
def set_seed(self, seed):
71+
"""
72+
Set the seed for the random engine.
73+
"""
74+
safe_call(backend.get().af_random_engine_set_seed(ct.pointer(self.engine), ct.c_longlong(seed)))
75+
76+
def get_seed(self):
77+
"""
78+
Get the seed for the random engine.
79+
"""
80+
seed = ct.c_longlong(0)
81+
safe_call(backend.get().af_random_engine_get_seed(ct.pointer(seed), self.engine))
82+
return seed.value
83+
84+
def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, random_engine=None):
85+
"""
86+
Create a multi dimensional array containing values from a uniform distribution.
87+
88+
Parameters
89+
----------
90+
d0 : int.
91+
Length of first dimension.
92+
93+
d1 : optional: int. default: None.
94+
Length of second dimension.
95+
96+
d2 : optional: int. default: None.
97+
Length of third dimension.
98+
99+
d3 : optional: int. default: None.
100+
Length of fourth dimension.
101+
102+
dtype : optional: af.Dtype. default: af.Dtype.f32.
103+
Data type of the array.
104+
105+
random_engine : optional: Random_Engine. default: None.
106+
If random_engine is None, uses a default engine created by arrayfire.
107+
108+
Returns
109+
-------
110+
111+
out : af.Array
112+
Multi dimensional array whose elements are sampled uniformly between [0, 1].
113+
- If d1 is None, `out` is 1D of size (d0,).
114+
- If d1 is not None and d2 is None, `out` is 2D of size (d0, d1).
115+
- If d1 and d2 are not None and d3 is None, `out` is 3D of size (d0, d1, d2).
116+
- If d1, d2, d3 are all not None, `out` is 4D of size (d0, d1, d2, d3).
117+
"""
118+
out = Array()
119+
dims = dim4(d0, d1, d2, d3)
120+
121+
if random_engine is None:
122+
safe_call(backend.get().af_randu(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
123+
else:
124+
safe_call(backend.get().af_random_uniform(ct.pointer(out.arr), 4, ct.pointer(dims), random_engine.engine))
125+
126+
return out
127+
128+
def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, random_engine=None):
129+
"""
130+
Create a multi dimensional array containing values from a normal distribution.
131+
132+
Parameters
133+
----------
134+
d0 : int.
135+
Length of first dimension.
136+
137+
d1 : optional: int. default: None.
138+
Length of second dimension.
139+
140+
d2 : optional: int. default: None.
141+
Length of third dimension.
142+
143+
d3 : optional: int. default: None.
144+
Length of fourth dimension.
145+
146+
dtype : optional: af.Dtype. default: af.Dtype.f32.
147+
Data type of the array.
148+
149+
random_engine : optional: Random_Engine. default: None.
150+
If random_engine is None, uses a default engine created by arrayfire.
151+
152+
Returns
153+
-------
154+
155+
out : af.Array
156+
Multi dimensional array whose elements are sampled from a normal distribution with mean 0 and sigma of 1.
157+
- If d1 is None, `out` is 1D of size (d0,).
158+
- If d1 is not None and d2 is None, `out` is 2D of size (d0, d1).
159+
- If d1 and d2 are not None and d3 is None, `out` is 3D of size (d0, d1, d2).
160+
- If d1, d2, d3 are all not None, `out` is 4D of size (d0, d1, d2, d3).
161+
"""
162+
163+
out = Array()
164+
dims = dim4(d0, d1, d2, d3)
165+
166+
if random_engine is None:
167+
safe_call(backend.get().af_randn(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
168+
else:
169+
safe_call(backend.get().af_random_normal(ct.pointer(out.arr), 4, ct.pointer(dims), random_engine.engine))
170+
171+
return out
172+
173+
def set_seed(seed=0):
174+
"""
175+
Set the seed for the random number generator.
176+
177+
Parameters
178+
----------
179+
seed: int.
180+
Seed for the random number generator
181+
"""
182+
safe_call(backend.get().af_set_seed(ct.c_ulonglong(seed)))
183+
184+
def get_seed():
185+
"""
186+
Get the seed for the random number generator.
187+
188+
Returns
189+
----------
190+
seed: int.
191+
Seed for the random number generator
192+
"""
193+
seed = ct.c_ulonglong(0)
194+
safe_call(backend.get().af_get_seed(ct.pointer(seed)))
195+
return seed.value
196+
197+
def set_default_random_engine_type(engine_type):
198+
"""
199+
Set random engine type for default random engine.
200+
201+
Parameters
202+
----------
203+
engine_type : RANDOME_ENGINE.
204+
- Specifies the type of random engine to be created. Can be one of:
205+
- RANDOM_ENGINE.PHILOX_4X32_10
206+
- RANDOM_ENGINE.THREEFRY_2X32_16
207+
- RANDOM_ENGINE.MERSENNE_GP11213
208+
- RANDOM_ENGINE.PHILOX (same as RANDOM_ENGINE.PHILOX_4X32_10)
209+
- RANDOM_ENGINE.THREEFRY (same as RANDOM_ENGINE.THREEFRY_2X32_16)
210+
- RANDOM_ENGINE.DEFAULT
211+
212+
Note
213+
----
214+
215+
This only affects randu and randn when a random engine is not specified.
216+
"""
217+
safe_call(backend.get().af_set_default_random_engine_type(ct.pointer(self.engine), engine_type.value))
218+
219+
def get_default_random_engine():
220+
"""
221+
Get the default random engine
222+
223+
Returns
224+
------
225+
226+
The default random engine used by randu and randn
227+
"""
228+
engine = ct.c_void_p(0)
229+
default_engine = ct.c_void_p(0)
230+
safe_call(backend.get().af_get_default_random_engine(ct.pointer(default_engine)))
231+
safe_call(backend.get().af_retain_random_engine(ct.pointer(engine), default_engine))
232+
return Random_Engine(engine=engine)

arrayfire/tests/simple/data.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,6 @@ def simple_data(verbose=False):
2424
display_func(af.range(3, 3))
2525
display_func(af.iota(3, 3, tile_dims=(2,2)))
2626

27-
display_func(af.randu(3, 3, 1, 2))
28-
display_func(af.randu(3, 3, 1, 2, af.Dtype.b8))
29-
display_func(af.randu(3, 3, dtype=af.Dtype.c32))
30-
31-
display_func(af.randn(3, 3, 1, 2))
32-
display_func(af.randn(3, 3, dtype=af.Dtype.c32))
33-
34-
af.set_seed(1024)
35-
assert(af.get_seed() == 1024)
36-
3727
display_func(af.identity(3, 3, 1, 2, af.Dtype.b8))
3828
display_func(af.identity(3, 3, dtype=af.Dtype.c32))
3929

0 commit comments

Comments
 (0)