Skip to content

Commit 00f325f

Browse files
FIX: rayleigh dist (#513)
1 parent b0f6461 commit 00f325f

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@
3333
#include "dpnp_utils.hpp"
3434
#include "queue_sycl.hpp"
3535

36-
namespace mkl_rng = oneapi::mkl::rng;
3736
namespace mkl_blas = oneapi::mkl::blas;
37+
namespace mkl_rng = oneapi::mkl::rng;
38+
namespace mkl_vm = oneapi::mkl::vm;
3839

3940
/**
4041
* Use get/set functions to access/modify this variable
@@ -461,15 +462,23 @@ void dpnp_rng_rayleigh_c(void* result, _DataType scale, size_t size)
461462
return;
462463
}
463464

464-
// set displacement a
465-
const _DataType a = (_DataType(0.0));
465+
cl::sycl::vector_class<cl::sycl::event> no_deps;
466+
467+
const _DataType a = 0.0;
468+
const _DataType beta = 2.0;
466469

467470
_DataType* result1 = reinterpret_cast<_DataType*>(result);
468471

469-
mkl_rng::rayleigh<_DataType> distribution(a, scale);
470-
// perform generation
472+
mkl_rng::exponential<_DataType> distribution(a, beta);;
473+
471474
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
472475
event_out.wait();
476+
event_out = mkl_vm::sqrt(DPNP_QUEUE, size, result1, result1, no_deps, mkl_vm::mode::ha);
477+
event_out.wait();
478+
// with MKL
479+
// event_out = mkl_blas::axpy(DPNP_QUEUE, size, scale, result1, 1, result1, 1);
480+
// event_out.wait();
481+
for(size_t i = 0; i < size; i++) result1[i] *= scale;
473482
}
474483

475484
template <typename _DataType>

tests/skipped_tests.tbl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ tests/test_linalg.py::test_norm3[(1, 2)-3-[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]
139139
tests/test_linalg.py::test_norm3[(1, 2)-3-[[[1, 0], [3, 0]], [[5, 0], [7, 0]]]]
140140
tests/test_linalg.py::test_svd[(3,4)-complex128]
141141
tests/test_linalg.py::test_svd[(5,3)-complex128]
142-
tests/test_random.py::TestDistributionsRayleigh::test_moments
143142
tests/third_party/cupy/binary_tests/test_elementwise.py::TestElementwise::test_bitwise_and
144143
tests/third_party/cupy/binary_tests/test_elementwise.py::TestElementwise::test_bitwise_or
145144
tests/third_party/cupy/binary_tests/test_elementwise.py::TestElementwise::test_bitwise_xor

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ tests/test_linalg.py::test_svd[(16,16)-float64]
156156
tests/test_linalg.py::test_svd[(16,16)-float32]
157157
tests/test_linalg.py::test_svd[(16,16)-int64]
158158
tests/test_linalg.py::test_svd[(16,16)-int32]
159-
tests/test_random.py::TestDistributionsRayleigh::test_moments
160159
tests/test_statistics.py::test_median[2-float32]
161160
tests/test_statistics.py::test_median[2-float64]
162161
tests/third_party/cupy/binary_tests/test_elementwise.py::TestElementwise::test_bitwise_and

0 commit comments

Comments
 (0)