23
23
// THE POSSIBILITY OF SUCH DAMAGE.
24
24
// *****************************************************************************
25
25
26
+ #include < mkl_vsl.h>
27
+ #include < stdexcept>
28
+ #include < vector>
29
+
26
30
#include < dpnp_iface.hpp>
31
+
27
32
#include " dpnp_fptr.hpp"
28
33
#include " dpnp_utils.hpp"
29
34
#include " queue_sycl.hpp"
30
35
31
- #include < vector>
32
-
33
36
namespace mkl_rng = oneapi::mkl::rng;
34
37
namespace mkl_blas = oneapi::mkl::blas;
35
38
39
+ /* *
40
+ * Use get/set functions to access/modify this variable
41
+ */
42
+ static VSLStreamStatePtr rng_stream = nullptr ;
43
+
44
+ static void set_rng_stream (size_t seed = 1 )
45
+ {
46
+ if (rng_stream)
47
+ {
48
+ vslDeleteStream (&rng_stream);
49
+ rng_stream = nullptr ;
50
+ }
51
+
52
+ vslNewStream (&rng_stream, VSL_BRNG_MT19937, seed);
53
+ }
54
+
55
+ static VSLStreamStatePtr get_rng_stream ()
56
+ {
57
+ if (!rng_stream)
58
+ {
59
+ set_rng_stream ();
60
+ }
61
+
62
+ return rng_stream;
63
+ }
64
+
65
+ void dpnp_srand_c (size_t seed)
66
+ {
67
+ backend_sycl::backend_sycl_rng_engine_init (seed);
68
+ set_rng_stream (seed);
69
+ }
70
+
36
71
template <typename _DataType>
37
72
void dpnp_rng_beta_c (void * result, _DataType a, _DataType b, size_t size)
38
73
{
@@ -47,10 +82,24 @@ void dpnp_rng_beta_c(void* result, _DataType a, _DataType b, size_t size)
47
82
48
83
_DataType* result1 = reinterpret_cast <_DataType*>(result);
49
84
50
- mkl_rng::beta<_DataType> distribution (a, b, displacement, scalefactor);
51
- // perform generation
52
- auto event_out = mkl_rng::generate (distribution, DPNP_RNG_ENGINE, size, result1);
53
- event_out.wait ();
85
+ if (dpnp_queue_is_cpu_c ())
86
+ {
87
+ mkl_rng::beta<_DataType> distribution (a, b, displacement, scalefactor);
88
+ // perform generation
89
+ auto event_out = mkl_rng::generate (distribution, DPNP_RNG_ENGINE, size, result1);
90
+ event_out.wait ();
91
+ }
92
+ else
93
+ {
94
+ int errcode =
95
+ vdRngBeta (VSL_RNG_METHOD_BETA_CJA, get_rng_stream (), size, result1, a, b, displacement, scalefactor);
96
+ if (errcode != VSL_STATUS_OK)
97
+ {
98
+ throw std::runtime_error (" DPNP RNG Error: dpnp_rng_beta_c() failed." );
99
+ }
100
+ }
101
+
102
+ return ;
54
103
}
55
104
56
105
template <typename _DataType>
@@ -439,7 +488,6 @@ void dpnp_rng_weibull_c(void* result, double alpha, size_t size)
439
488
void func_map_init_random (func_map_t & fmap)
440
489
{
441
490
fmap[DPNPFuncName::DPNP_FN_RNG_BETA][eft_DBL][eft_DBL] = {eft_DBL, (void *)dpnp_rng_beta_c<double >};
442
- fmap[DPNPFuncName::DPNP_FN_RNG_BETA][eft_FLT][eft_FLT] = {eft_FLT, (void *)dpnp_rng_beta_c<float >};
443
491
444
492
fmap[DPNPFuncName::DPNP_FN_RNG_BINOMIAL][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_rng_binomial_c<int >};
445
493
0 commit comments