Skip to content

Commit d4f31e6

Browse files
ENH: dpnp backend fallback to classic MKL for random (#498)
* ENH: dpnp backend fallback to classic MKL for random
1 parent 67d4699 commit d4f31e6

File tree

1 file changed

+115
-33
lines changed

1 file changed

+115
-33
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 115 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ void dpnp_rng_beta_c(void* result, _DataType a, _DataType b, size_t size)
9191
}
9292
else
9393
{
94-
int errcode =
95-
vdRngBeta(VSL_RNG_METHOD_BETA_CJA, get_rng_stream(), size, result1, a, b, displacement, scalefactor);
94+
int errcode = vdRngBeta(VSL_RNG_METHOD_BETA_CJA, get_rng_stream(), size,
95+
result1, a, b, displacement, scalefactor);
9696
if (errcode != VSL_STATUS_OK)
9797
{
9898
throw std::runtime_error("DPNP RNG Error: dpnp_rng_beta_c() failed.");
@@ -111,10 +111,22 @@ void dpnp_rng_binomial_c(void* result, int ntrial, double p, size_t size)
111111
}
112112
_DataType* result1 = reinterpret_cast<_DataType*>(result);
113113

114-
mkl_rng::binomial<_DataType> distribution(ntrial, p);
115-
// perform generation
116-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
117-
event_out.wait();
114+
if (dpnp_queue_is_cpu_c())
115+
{
116+
mkl_rng::binomial<_DataType> distribution(ntrial, p);
117+
// perform generation
118+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
119+
event_out.wait();
120+
}
121+
else
122+
{
123+
int errcode = viRngBinomial(VSL_RNG_METHOD_BINOMIAL_BTPE, get_rng_stream(),
124+
size, result1, ntrial, p);
125+
if (errcode != VSL_STATUS_OK)
126+
{
127+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_binomial_c() failed.");
128+
}
129+
}
118130
}
119131

120132
template <typename _DataType>
@@ -126,10 +138,22 @@ void dpnp_rng_chi_square_c(void* result, int df, size_t size)
126138
}
127139
_DataType* result1 = reinterpret_cast<_DataType*>(result);
128140

129-
mkl_rng::chi_square<_DataType> distribution(df);
130-
// perform generation
131-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
132-
event_out.wait();
141+
if (dpnp_queue_is_cpu_c())
142+
{
143+
mkl_rng::chi_square<_DataType> distribution(df);
144+
// perform generation
145+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
146+
event_out.wait();
147+
}
148+
else
149+
{
150+
int errcode = vdRngChiSquare(VSL_RNG_METHOD_CHISQUARE_CHI2GAMMA, get_rng_stream(),
151+
size, result1, df);
152+
if (errcode != VSL_STATUS_OK)
153+
{
154+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_chi_square_c() failed.");
155+
}
156+
}
133157
}
134158

135159
template <typename _DataType>
@@ -164,10 +188,22 @@ void dpnp_rng_gamma_c(void* result, _DataType shape, _DataType scale, size_t siz
164188

165189
_DataType* result1 = reinterpret_cast<_DataType*>(result);
166190

167-
mkl_rng::gamma<_DataType> distribution(shape, a, scale);
168-
// perform generation
169-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
170-
event_out.wait();
191+
if (dpnp_queue_is_cpu_c())
192+
{
193+
mkl_rng::gamma<_DataType> distribution(shape, a, scale);
194+
// perform generation
195+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
196+
event_out.wait();
197+
}
198+
else
199+
{
200+
int errcode = vdRngGamma(VSL_RNG_METHOD_GAMMA_GNORM, get_rng_stream(), size,
201+
result1, shape, a, scale);
202+
if (errcode != VSL_STATUS_OK)
203+
{
204+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_gamma_c() failed.");
205+
}
206+
}
171207
}
172208

173209
template <typename _DataType>
@@ -231,10 +267,22 @@ void dpnp_rng_hypergeometric_c(void* result, int l, int s, int m, size_t size)
231267
}
232268
_DataType* result1 = reinterpret_cast<_DataType*>(result);
233269

234-
mkl_rng::hypergeometric<_DataType> distribution(l, s, m);
235-
// perform generation
236-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
237-
event_out.wait();
270+
if (dpnp_queue_is_cpu_c())
271+
{
272+
mkl_rng::hypergeometric<_DataType> distribution(l, s, m);
273+
// perform generation
274+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
275+
event_out.wait();
276+
}
277+
else
278+
{
279+
int errcode = viRngHypergeometric(VSL_RNG_METHOD_HYPERGEOMETRIC_H2PE, get_rng_stream(),
280+
size, result1, l, s, m);
281+
if (errcode != VSL_STATUS_OK)
282+
{
283+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_hypergeometric_c() failed.");
284+
}
285+
}
238286
}
239287

240288
template <typename _DataType>
@@ -272,24 +320,40 @@ void dpnp_rng_lognormal_c(void* result, _DataType mean, _DataType stddev, size_t
272320
}
273321

274322
template <typename _DataType>
275-
void dpnp_rng_multinomial_c(void* result, int ntrial, const double* p_vector, const size_t p_vector_size, size_t size)
323+
void dpnp_rng_multinomial_c(void* result,
324+
int ntrial,
325+
const double* p_vector,
326+
const size_t p_vector_size,
327+
size_t size)
276328
{
277329
if (!size)
278330
{
279331
return;
280332
}
281333
std::int32_t* result1 = reinterpret_cast<std::int32_t*>(result);
282334
std::vector<double> p(p_vector, p_vector + p_vector_size);
283-
284-
mkl_rng::multinomial<std::int32_t> distribution(ntrial, p);
285335
// size = size
286336
// `result` is a array for random numbers
287337
// `size` is a `result`'s len. `size = n * p.size()`
288338
// `n` is a number of random values to be generated.
289339
size_t n = size / p.size();
290-
// perform generation
291-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, n, result1);
292-
event_out.wait();
340+
341+
if (dpnp_queue_is_cpu_c())
342+
{
343+
mkl_rng::multinomial<std::int32_t> distribution(ntrial, p);
344+
// perform generation
345+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, n, result1);
346+
event_out.wait();
347+
}
348+
else
349+
{
350+
int errcode = viRngMultinomial(VSL_RNG_METHOD_MULTINOMIAL_MULTPOISSON, get_rng_stream(),
351+
n, result1, ntrial, p_vector_size, p_vector);
352+
if (errcode != VSL_STATUS_OK)
353+
{
354+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_multinomial_c() failed.");
355+
}
356+
}
293357
}
294358

295359
template <typename _DataType>
@@ -313,11 +377,19 @@ void dpnp_rng_multivariate_normal_c(void* result,
313377
// `result` is a array for random numbers
314378
// `size` is a `result`'s len.
315379
// `size1` is a number of random values to be generated for each dimension.
316-
mkl_rng::gaussian_mv<_DataType> distribution(dimen, mean, cov);
317380
size_t size1 = size / dimen;
318381

319-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size1, result1);
320-
event_out.wait();
382+
if (dpnp_queue_is_cpu_c())
383+
{
384+
mkl_rng::gaussian_mv<_DataType> distribution(dimen, mean, cov);
385+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size1, result1);
386+
event_out.wait();
387+
}
388+
else
389+
{
390+
int errcode = vdRngGaussianMV(VSL_RNG_METHOD_GAUSSIANMV_BOXMULLER2, get_rng_stream(),
391+
size1, result1, dimen, VSL_MATRIX_STORAGE_FULL, mean_vector, cov_vector );
392+
}
321393
}
322394

323395
template <typename _DataType>
@@ -329,10 +401,22 @@ void dpnp_rng_negative_binomial_c(void* result, double a, double p, size_t size)
329401
}
330402
_DataType* result1 = reinterpret_cast<_DataType*>(result);
331403

332-
mkl_rng::negative_binomial<_DataType> distribution(a, p);
333-
// perform generation
334-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
335-
event_out.wait();
404+
if (dpnp_queue_is_cpu_c())
405+
{
406+
mkl_rng::negative_binomial<_DataType> distribution(a, p);
407+
// perform generation
408+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
409+
event_out.wait();
410+
}
411+
else
412+
{
413+
int errcode = viRngNegbinomial(VSL_RNG_METHOD_NEGBINOMIAL_NBAR, get_rng_stream(),
414+
size, result1, a, p);
415+
if (errcode != VSL_STATUS_OK)
416+
{
417+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_negative_binomial_c() failed.");
418+
}
419+
}
336420
}
337421

338422
template <typename _DataType>
@@ -492,13 +576,11 @@ void func_map_init_random(func_map_t& fmap)
492576
fmap[DPNPFuncName::DPNP_FN_RNG_BINOMIAL][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_rng_binomial_c<int>};
493577

494578
fmap[DPNPFuncName::DPNP_FN_RNG_CHISQUARE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_rng_chi_square_c<double>};
495-
fmap[DPNPFuncName::DPNP_FN_RNG_CHISQUARE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_rng_chi_square_c<float>};
496579

497580
fmap[DPNPFuncName::DPNP_FN_RNG_EXPONENTIAL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_rng_exponential_c<double>};
498581
fmap[DPNPFuncName::DPNP_FN_RNG_EXPONENTIAL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_rng_exponential_c<float>};
499582

500583
fmap[DPNPFuncName::DPNP_FN_RNG_GAMMA][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_rng_gamma_c<double>};
501-
fmap[DPNPFuncName::DPNP_FN_RNG_GAMMA][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_rng_gamma_c<float>};
502584

503585
fmap[DPNPFuncName::DPNP_FN_RNG_GAUSSIAN][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_rng_gaussian_c<double>};
504586
fmap[DPNPFuncName::DPNP_FN_RNG_GAUSSIAN][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_rng_gaussian_c<float>};

0 commit comments

Comments
 (0)