Skip to content

Commit e02bd94

Browse files
ENH: dpnp backend fallback to classic MKL; update random.beta (#493)
* ENH: dpnp backend fallback to classic MKL; update random.beta
1 parent 8259f82 commit e02bd94

File tree

2 files changed

+55
-12
lines changed

2 files changed

+55
-12
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,51 @@
2323
// THE POSSIBILITY OF SUCH DAMAGE.
2424
//*****************************************************************************
2525

26+
#include <mkl_vsl.h>
27+
#include <stdexcept>
28+
#include <vector>
29+
2630
#include <dpnp_iface.hpp>
31+
2732
#include "dpnp_fptr.hpp"
2833
#include "dpnp_utils.hpp"
2934
#include "queue_sycl.hpp"
3035

31-
#include <vector>
32-
3336
namespace mkl_rng = oneapi::mkl::rng;
3437
namespace mkl_blas = oneapi::mkl::blas;
3538

39+
/**
40+
* Use get/set functions to access/modify this variable
41+
*/
42+
static VSLStreamStatePtr rng_stream = nullptr;
43+
44+
static void set_rng_stream(size_t seed = 1)
45+
{
46+
if (rng_stream)
47+
{
48+
vslDeleteStream(&rng_stream);
49+
rng_stream = nullptr;
50+
}
51+
52+
vslNewStream(&rng_stream, VSL_BRNG_MT19937, seed);
53+
}
54+
55+
static VSLStreamStatePtr get_rng_stream()
56+
{
57+
if (!rng_stream)
58+
{
59+
set_rng_stream();
60+
}
61+
62+
return rng_stream;
63+
}
64+
65+
void dpnp_srand_c(size_t seed)
66+
{
67+
backend_sycl::backend_sycl_rng_engine_init(seed);
68+
set_rng_stream(seed);
69+
}
70+
3671
template <typename _DataType>
3772
void dpnp_rng_beta_c(void* result, _DataType a, _DataType b, size_t size)
3873
{
@@ -47,10 +82,24 @@ void dpnp_rng_beta_c(void* result, _DataType a, _DataType b, size_t size)
4782

4883
_DataType* result1 = reinterpret_cast<_DataType*>(result);
4984

50-
mkl_rng::beta<_DataType> distribution(a, b, displacement, scalefactor);
51-
// perform generation
52-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
53-
event_out.wait();
85+
if (dpnp_queue_is_cpu_c())
86+
{
87+
mkl_rng::beta<_DataType> distribution(a, b, displacement, scalefactor);
88+
// perform generation
89+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
90+
event_out.wait();
91+
}
92+
else
93+
{
94+
int errcode =
95+
vdRngBeta(VSL_RNG_METHOD_BETA_CJA, get_rng_stream(), size, result1, a, b, displacement, scalefactor);
96+
if (errcode != VSL_STATUS_OK)
97+
{
98+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_beta_c() failed.");
99+
}
100+
}
101+
102+
return;
54103
}
55104

56105
template <typename _DataType>
@@ -439,7 +488,6 @@ void dpnp_rng_weibull_c(void* result, double alpha, size_t size)
439488
void func_map_init_random(func_map_t& fmap)
440489
{
441490
fmap[DPNPFuncName::DPNP_FN_RNG_BETA][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_rng_beta_c<double>};
442-
fmap[DPNPFuncName::DPNP_FN_RNG_BETA][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_rng_beta_c<float>};
443491

444492
fmap[DPNPFuncName::DPNP_FN_RNG_BINOMIAL][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_rng_binomial_c<int>};
445493

dpnp/backend/src/queue_sycl.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,3 @@ size_t dpnp_queue_is_cpu_c()
151151
{
152152
return backend_sycl::backend_sycl_is_cpu();
153153
}
154-
155-
void dpnp_srand_c(size_t seed)
156-
{
157-
backend_sycl::backend_sycl_rng_engine_init(seed);
158-
}

0 commit comments

Comments
 (0)