|
| 1 | +####################################################### |
| 2 | +# Copyright (c) 2015, ArrayFire |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# This file is distributed under 3-clause BSD license. |
| 6 | +# The complete license agreement can be obtained at: |
| 7 | +# http://arrayfire.com/licenses/BSD-3-Clause |
| 8 | +######################################################## |
| 9 | + |
| 10 | +""" |
| 11 | +Random engine class and functions to generate random numbers. |
| 12 | +""" |
| 13 | + |
| 14 | +from .library import * |
| 15 | +from .array import * |
| 16 | +import numbers |
| 17 | + |
| 18 | +class Random_Engine(object): |
| 19 | + """ |
| 20 | + Class to handle random number generator engines. |
| 21 | +
|
| 22 | + Parameters |
| 23 | + ---------- |
| 24 | +
|
| 25 | + engine_type : optional: RANDOME_ENGINE. default: RANDOM_ENGINE.PHILOX |
| 26 | + - Specifies the type of random engine to be created. Can be one of: |
| 27 | + - RANDOM_ENGINE.PHILOX_4X32_10 |
| 28 | + - RANDOM_ENGINE.THREEFRY_2X32_16 |
| 29 | + - RANDOM_ENGINE.MERSENNE_GP11213 |
| 30 | + - RANDOM_ENGINE.PHILOX (same as RANDOM_ENGINE.PHILOX_4X32_10) |
| 31 | + - RANDOM_ENGINE.THREEFRY (same as RANDOM_ENGINE.THREEFRY_2X32_16) |
| 32 | + - RANDOM_ENGINE.DEFAULT |
| 33 | + - Not used if engine is not None |
| 34 | +
|
| 35 | + seed : optional int. default: 0 |
| 36 | + - Specifies the seed for the random engine |
| 37 | + - Not used if engine is not None |
| 38 | +
|
| 39 | + engine : optional ctypes.c_void_p. default: None. |
| 40 | + - Used a handle created by the C api to create the Random_Engine. |
| 41 | + """ |
| 42 | + |
| 43 | + def __init__(self, engine_type = RANDOM_ENGINE.PHILOX, seed = 0, engine = None): |
| 44 | + if (engine is None): |
| 45 | + self.engine = ct.c_void_p(0) |
| 46 | + safe_call(backend.get().af_create_random_engine(ct.pointer(self.engine), engine_type.value, ct.c_longlong(seed))) |
| 47 | + else: |
| 48 | + self.engine = engine |
| 49 | + |
| 50 | + def __del__(self): |
| 51 | + safe_call(backend.get().af_release_random_engine(self.engine)) |
| 52 | + |
| 53 | + def set_type(self, engine_type): |
| 54 | + """ |
| 55 | + Set the type of the random engine. |
| 56 | + """ |
| 57 | + safe_call(backend.get().af_random_engine_set_type(ct.pointer(self.engine), engine_type.value)) |
| 58 | + |
| 59 | + def get_type(self): |
| 60 | + """ |
| 61 | + Get the type of the random engine. |
| 62 | + """ |
| 63 | + __to_random_engine_type = [RANDOM_ENGINE.PHILOX_4X32_10, |
| 64 | + RANDOM_ENGINE.THREEFRY_2X32_16, |
| 65 | + RANDOM_ENGINE.MERSENNE_GP11213] |
| 66 | + rty = ct.c_int(RANDOM_ENGINE.PHILOX.value) |
| 67 | + safe_call(backend.get().af_random_engine_get_type(ct.pointer(rty), self.engine)) |
| 68 | + return __to_random_engine_type[rty] |
| 69 | + |
| 70 | + def set_seed(self, seed): |
| 71 | + """ |
| 72 | + Set the seed for the random engine. |
| 73 | + """ |
| 74 | + safe_call(backend.get().af_random_engine_set_seed(ct.pointer(self.engine), ct.c_longlong(seed))) |
| 75 | + |
| 76 | + def get_seed(self): |
| 77 | + """ |
| 78 | + Get the seed for the random engine. |
| 79 | + """ |
| 80 | + seed = ct.c_longlong(0) |
| 81 | + safe_call(backend.get().af_random_engine_get_seed(ct.pointer(seed), self.engine)) |
| 82 | + return seed.value |
| 83 | + |
| 84 | +def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, random_engine=None): |
| 85 | + """ |
| 86 | + Create a multi dimensional array containing values from a uniform distribution. |
| 87 | +
|
| 88 | + Parameters |
| 89 | + ---------- |
| 90 | + d0 : int. |
| 91 | + Length of first dimension. |
| 92 | +
|
| 93 | + d1 : optional: int. default: None. |
| 94 | + Length of second dimension. |
| 95 | +
|
| 96 | + d2 : optional: int. default: None. |
| 97 | + Length of third dimension. |
| 98 | +
|
| 99 | + d3 : optional: int. default: None. |
| 100 | + Length of fourth dimension. |
| 101 | +
|
| 102 | + dtype : optional: af.Dtype. default: af.Dtype.f32. |
| 103 | + Data type of the array. |
| 104 | +
|
| 105 | + random_engine : optional: Random_Engine. default: None. |
| 106 | + If random_engine is None, uses a default engine created by arrayfire. |
| 107 | +
|
| 108 | + Returns |
| 109 | + ------- |
| 110 | +
|
| 111 | + out : af.Array |
| 112 | + Multi dimensional array whose elements are sampled uniformly between [0, 1]. |
| 113 | + - If d1 is None, `out` is 1D of size (d0,). |
| 114 | + - If d1 is not None and d2 is None, `out` is 2D of size (d0, d1). |
| 115 | + - If d1 and d2 are not None and d3 is None, `out` is 3D of size (d0, d1, d2). |
| 116 | + - If d1, d2, d3 are all not None, `out` is 4D of size (d0, d1, d2, d3). |
| 117 | + """ |
| 118 | + out = Array() |
| 119 | + dims = dim4(d0, d1, d2, d3) |
| 120 | + |
| 121 | + if random_engine is None: |
| 122 | + safe_call(backend.get().af_randu(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value)) |
| 123 | + else: |
| 124 | + safe_call(backend.get().af_random_uniform(ct.pointer(out.arr), 4, ct.pointer(dims), random_engine.engine)) |
| 125 | + |
| 126 | + return out |
| 127 | + |
| 128 | +def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, random_engine=None): |
| 129 | + """ |
| 130 | + Create a multi dimensional array containing values from a normal distribution. |
| 131 | +
|
| 132 | + Parameters |
| 133 | + ---------- |
| 134 | + d0 : int. |
| 135 | + Length of first dimension. |
| 136 | +
|
| 137 | + d1 : optional: int. default: None. |
| 138 | + Length of second dimension. |
| 139 | +
|
| 140 | + d2 : optional: int. default: None. |
| 141 | + Length of third dimension. |
| 142 | +
|
| 143 | + d3 : optional: int. default: None. |
| 144 | + Length of fourth dimension. |
| 145 | +
|
| 146 | + dtype : optional: af.Dtype. default: af.Dtype.f32. |
| 147 | + Data type of the array. |
| 148 | +
|
| 149 | + random_engine : optional: Random_Engine. default: None. |
| 150 | + If random_engine is None, uses a default engine created by arrayfire. |
| 151 | +
|
| 152 | + Returns |
| 153 | + ------- |
| 154 | +
|
| 155 | + out : af.Array |
| 156 | + Multi dimensional array whose elements are sampled from a normal distribution with mean 0 and sigma of 1. |
| 157 | + - If d1 is None, `out` is 1D of size (d0,). |
| 158 | + - If d1 is not None and d2 is None, `out` is 2D of size (d0, d1). |
| 159 | + - If d1 and d2 are not None and d3 is None, `out` is 3D of size (d0, d1, d2). |
| 160 | + - If d1, d2, d3 are all not None, `out` is 4D of size (d0, d1, d2, d3). |
| 161 | + """ |
| 162 | + |
| 163 | + out = Array() |
| 164 | + dims = dim4(d0, d1, d2, d3) |
| 165 | + |
| 166 | + if random_engine is None: |
| 167 | + safe_call(backend.get().af_randn(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value)) |
| 168 | + else: |
| 169 | + safe_call(backend.get().af_random_normal(ct.pointer(out.arr), 4, ct.pointer(dims), random_engine.engine)) |
| 170 | + |
| 171 | + return out |
| 172 | + |
| 173 | +def set_seed(seed=0): |
| 174 | + """ |
| 175 | + Set the seed for the random number generator. |
| 176 | +
|
| 177 | + Parameters |
| 178 | + ---------- |
| 179 | + seed: int. |
| 180 | + Seed for the random number generator |
| 181 | + """ |
| 182 | + safe_call(backend.get().af_set_seed(ct.c_ulonglong(seed))) |
| 183 | + |
| 184 | +def get_seed(): |
| 185 | + """ |
| 186 | + Get the seed for the random number generator. |
| 187 | +
|
| 188 | + Returns |
| 189 | + ---------- |
| 190 | + seed: int. |
| 191 | + Seed for the random number generator |
| 192 | + """ |
| 193 | + seed = ct.c_ulonglong(0) |
| 194 | + safe_call(backend.get().af_get_seed(ct.pointer(seed))) |
| 195 | + return seed.value |
| 196 | + |
| 197 | +def set_default_random_engine_type(engine_type): |
| 198 | + """ |
| 199 | + Set random engine type for default random engine. |
| 200 | +
|
| 201 | + Parameters |
| 202 | + ---------- |
| 203 | + engine_type : RANDOME_ENGINE. |
| 204 | + - Specifies the type of random engine to be created. Can be one of: |
| 205 | + - RANDOM_ENGINE.PHILOX_4X32_10 |
| 206 | + - RANDOM_ENGINE.THREEFRY_2X32_16 |
| 207 | + - RANDOM_ENGINE.MERSENNE_GP11213 |
| 208 | + - RANDOM_ENGINE.PHILOX (same as RANDOM_ENGINE.PHILOX_4X32_10) |
| 209 | + - RANDOM_ENGINE.THREEFRY (same as RANDOM_ENGINE.THREEFRY_2X32_16) |
| 210 | + - RANDOM_ENGINE.DEFAULT |
| 211 | +
|
| 212 | + Note |
| 213 | + ---- |
| 214 | +
|
| 215 | + This only affects randu and randn when a random engine is not specified. |
| 216 | + """ |
| 217 | + safe_call(backend.get().af_set_default_random_engine_type(ct.pointer(self.engine), engine_type.value)) |
| 218 | + |
| 219 | +def get_default_random_engine(): |
| 220 | + """ |
| 221 | + Get the default random engine |
| 222 | +
|
| 223 | + Returns |
| 224 | + ------ |
| 225 | +
|
| 226 | + The default random engine used by randu and randn |
| 227 | + """ |
| 228 | + engine = ct.c_void_p(0) |
| 229 | + default_engine = ct.c_void_p(0) |
| 230 | + safe_call(backend.get().af_get_default_random_engine(ct.pointer(default_engine))) |
| 231 | + safe_call(backend.get().af_retain_random_engine(ct.pointer(engine), default_engine)) |
| 232 | + return Random_Engine(engine=engine) |
0 commit comments