Skip to content

Commit 4248b84

Browse files
MAINT: refactor random.noncentral_chisquare (#730)
1 parent 9ea2351 commit 4248b84

File tree

1 file changed

+9
-15
lines changed

1 file changed

+9
-15
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ void dpnp_rng_negative_binomial_c(void* result, const double a, const double p,
600600
template <typename _DataType>
601601
void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _DataType nonc, const size_t size)
602602
{
603-
if (!size)
603+
if (!size || !result)
604604
{
605605
return;
606606
}
@@ -614,8 +614,6 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
614614
{
615615
_DataType shape, loc;
616616
size_t i;
617-
cl::sycl::event event_out;
618-
cl::sycl::vector_class<cl::sycl::event> no_deps;
619617

620618
if (df > 1)
621619
{
@@ -624,23 +622,20 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
624622
shape = 0.5 * (df - 1.0);
625623
/* res has chi^2 with (df - 1) */
626624
mkl_rng::gamma<_DataType> gamma_distribution(shape, d_zero, d_two);
627-
event_out = mkl_rng::generate(gamma_distribution, DPNP_RNG_ENGINE, size, result1);
628-
event_out.wait();
625+
auto event_gamma_distr = mkl_rng::generate(gamma_distribution, DPNP_RNG_ENGINE, size, result1);
629626

630627
nvec = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(size * sizeof(_DataType)));
631628

632629
loc = sqrt(nonc);
633630

634631
mkl_rng::gaussian<_DataType> gaussian_distribution(loc, d_one);
635-
event_out = mkl_rng::generate(gaussian_distribution, DPNP_RNG_ENGINE, size, nvec);
636-
event_out.wait();
632+
auto event_gaussian_distr = mkl_rng::generate(gaussian_distribution, DPNP_RNG_ENGINE, size, nvec);
637633

638634
/* squaring could result in an overflow */
639-
event_out = mkl_vm::sqr(DPNP_QUEUE, size, nvec, nvec, no_deps, mkl_vm::mode::ha);
640-
event_out.wait();
641-
event_out = mkl_vm::add(DPNP_QUEUE, size, result1, nvec, result1, no_deps, mkl_vm::mode::ha);
635+
auto event_sqr_out = mkl_vm::sqr(DPNP_QUEUE, size, nvec, nvec, {event_gamma_distr, event_gaussian_distr}, mkl_vm::mode::ha);
636+
auto event_add_out = mkl_vm::add(DPNP_QUEUE, size, result1, nvec, result1, {event_sqr_out}, mkl_vm::mode::ha);
642637
dpnp_memory_free_c(nvec);
643-
event_out.wait();
638+
event_add_out.wait();
644639
}
645640
else if (df < 1)
646641
{
@@ -651,7 +646,7 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
651646
lambda = 0.5 * nonc;
652647

653648
mkl_rng::poisson<int> poisson_distribution(lambda);
654-
event_out = mkl_rng::generate(poisson_distribution, DPNP_RNG_ENGINE, size, pvec);
649+
auto event_out = mkl_rng::generate(poisson_distribution, DPNP_RNG_ENGINE, size, pvec);
655650
event_out.wait();
656651

657652
shape = 0.5 * df;
@@ -713,9 +708,8 @@ void dpnp_rng_noncentral_chisquare_c(void* result, const _DataType df, const _Da
713708
/* noncentral_chisquare(1, nonc) ~ (Z + sqrt(nonc))**2 for df == 1 */
714709
loc = sqrt(nonc);
715710
mkl_rng::gaussian<_DataType> gaussian_distribution(loc, d_one);
716-
event_out = mkl_rng::generate(gaussian_distribution, DPNP_RNG_ENGINE, size, result1);
717-
event_out.wait();
718-
event_out = mkl_vm::sqr(DPNP_QUEUE, size, result1, result1, no_deps, mkl_vm::mode::ha);
711+
auto event_gaussian_distr = mkl_rng::generate(gaussian_distribution, DPNP_RNG_ENGINE, size, result1);
712+
auto event_out = mkl_vm::sqr(DPNP_QUEUE, size, result1, result1, {event_gaussian_distr}, mkl_vm::mode::ha);
719713
event_out.wait();
720714
}
721715
}

0 commit comments

Comments
 (0)