Skip to content

Commit 0e80c50

Browse files
ENH: random.standard_t impl (#515)
* ENH: random.standard_t impl
1 parent 232a01b commit 0e80c50

File tree

8 files changed

+156
-7
lines changed

8 files changed

+156
-7
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,20 @@ INP_DLLEXPORT void dpnp_rng_standard_gamma_c(void* result, _DataType shape, size
948948
template <typename _DataType>
949949
INP_DLLEXPORT void dpnp_rng_standard_normal_c(void* result, size_t size);
950950

951+
/**
952+
* @ingroup BACKEND_API
953+
* @brief math library implementation of random number generator (standard Student's t distribution)
954+
*
955+
* @param [in] size Number of elements in `result` arrays.
956+
*
957+
* @param [in] df Degrees of freedom.
958+
*
959+
* @param [out] result Output array.
960+
*
961+
*/
962+
template <typename _DataType>
963+
INP_DLLEXPORT void dpnp_rng_standard_t_c(void* result, _DataType df, size_t size);
964+
951965
/**
952966
* @ingroup BACKEND_API
953967
* @brief math library implementation of random number generator (uniform distribution)

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ enum class DPNPFuncName : size_t
144144
DPNP_FN_RNG_STANDARD_EXPONENTIAL, /**< Used in numpy.random.standard_exponential() implementation */
145145
DPNP_FN_RNG_STANDARD_GAMMA, /**< Used in numpy.random.standard_gamma() implementation */
146146
DPNP_FN_RNG_STANDARD_NORMAL, /**< Used in numpy.random.standard_normal() implementation */
147+
DPNP_FN_RNG_STANDARD_T, /**< Used in numpy.random.standard_t() implementation */
147148
DPNP_FN_RNG_UNIFORM, /**< Used in numpy.random.uniform() implementation */
148149
DPNP_FN_RNG_WEIBULL, /**< Used in numpy.random.weibull() implementation */
149150
DPNP_FN_SIGN, /**< Used in numpy.sign() implementation */

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,74 @@ void dpnp_rng_standard_normal_c(void* result, size_t size)
564564
dpnp_rng_normal_c(result, mean, stddev, size);
565565
}
566566

567+
template <typename _DataType>
568+
void dpnp_rng_standard_t_c(void* result, _DataType df, size_t size)
569+
{
570+
if (!size)
571+
{
572+
return;
573+
}
574+
cl::sycl::vector_class<cl::sycl::event> no_deps;
575+
576+
_DataType* result1 = reinterpret_cast<_DataType*>(result);
577+
const _DataType d_zero = 0.0, d_one = 1.0;
578+
_DataType shape = df/2;
579+
_DataType *sn = nullptr;
580+
581+
if (dpnp_queue_is_cpu_c())
582+
{
583+
mkl_rng::gamma<_DataType> gamma_distribution(shape, d_zero, 1.0/shape);
584+
auto event_out = mkl_rng::generate(gamma_distribution, DPNP_RNG_ENGINE,
585+
size, result1);
586+
event_out.wait();
587+
event_out = mkl_vm::invsqrt(DPNP_QUEUE, size, result1, result1, no_deps,
588+
mkl_vm::mode::ha);
589+
event_out.wait();
590+
591+
sn = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(size * sizeof(_DataType)));
592+
if (sn == nullptr)
593+
{
594+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_standard_t_c() failed.");
595+
}
596+
597+
mkl_rng::gaussian<_DataType> gaussian_distribution(d_zero, d_one);
598+
event_out = mkl_rng::generate(gaussian_distribution, DPNP_RNG_ENGINE, size, sn);
599+
event_out.wait();
600+
601+
event_out = mkl_vm::mul(DPNP_QUEUE, size, result1, sn, result1, no_deps,
602+
mkl_vm::mode::ha);
603+
dpnp_memory_free_c(sn);
604+
event_out.wait();
605+
}
606+
else
607+
{
608+
int errcode = vdRngGamma(VSL_RNG_METHOD_GAMMA_GNORM_ACCURATE, get_rng_stream(),
609+
size, result1, shape, d_zero, 1.0/shape);
610+
611+
if (errcode != VSL_STATUS_OK)
612+
{
613+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_standard_t_c() failed.");
614+
}
615+
616+
vmdInvSqrt(size, result1, result1, VML_HA);
617+
618+
sn = (_DataType *) mkl_malloc(size * sizeof(_DataType), 64);
619+
if (sn == nullptr)
620+
{
621+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_standard_t_c() failed.");
622+
}
623+
624+
errcode = vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, get_rng_stream(), size, sn,
625+
d_zero, d_one);
626+
if (errcode != VSL_STATUS_OK)
627+
{
628+
throw std::runtime_error("DPNP RNG Error: dpnp_rng_standard_t_c() failed.");
629+
}
630+
vmdMul(size, result1, sn, result1, VML_HA);
631+
mkl_free(sn);
632+
}
633+
}
634+
567635
template <typename _DataType>
568636
void dpnp_rng_uniform_c(void* result, long low, long high, size_t size)
569637
{
@@ -658,6 +726,8 @@ void func_map_init_random(func_map_t& fmap)
658726

659727
fmap[DPNPFuncName::DPNP_FN_RNG_STANDARD_NORMAL][eft_DBL][eft_DBL] = {eft_DBL,
660728
(void*)dpnp_rng_standard_normal_c<double>};
729+
fmap[DPNPFuncName::DPNP_FN_RNG_STANDARD_T][eft_DBL][eft_DBL] = {eft_DBL,
730+
(void*)dpnp_rng_standard_t_c<double>};
661731

662732
fmap[DPNPFuncName::DPNP_FN_RNG_UNIFORM][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_rng_uniform_c<double>};
663733
fmap[DPNPFuncName::DPNP_FN_RNG_UNIFORM][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_rng_uniform_c<float>};

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
117117
DPNP_FN_RNG_STANDARD_EXPONENTIAL
118118
DPNP_FN_RNG_STANDARD_GAMMA
119119
DPNP_FN_RNG_STANDARD_NORMAL
120+
DPNP_FN_RNG_STANDARD_T
120121
DPNP_FN_RNG_UNIFORM
121122
DPNP_FN_RNG_WEIBULL
122123
DPNP_FN_SIGN

dpnp/random/dpnp_algo_random.pyx

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ __all__ = [
6767
"dpnp_standard_exponential",
6868
"dpnp_standard_gamma",
6969
"dpnp_standard_normal",
70+
"dpnp_standard_t",
7071
"dpnp_uniform",
7172
"dpnp_weibull"
7273
]
@@ -100,6 +101,7 @@ ctypedef void(*fptr_dpnp_rng_standard_cauchy_c_1out_t)(void *, size_t) except +
100101
ctypedef void(*fptr_dpnp_rng_standard_exponential_c_1out_t)(void *, size_t) except +
101102
ctypedef void(*fptr_dpnp_rng_standard_gamma_c_1out_t)(void *, double, size_t) except +
102103
ctypedef void(*fptr_dpnp_rng_standard_normal_c_1out_t)(void *, size_t) except +
104+
ctypedef void(*fptr_dpnp_rng_standard_t_c_1out_t)(void *, double, size_t) except +
103105
ctypedef void(*fptr_dpnp_rng_uniform_c_1out_t)(void *, long, long, size_t) except +
104106
ctypedef void(*fptr_dpnp_rng_weibull_c_1out_t)(void *, double, size_t) except +
105107

@@ -867,6 +869,30 @@ cpdef dparray dpnp_standard_normal(size):
867869

868870
return result
869871

872+
cpdef dparray dpnp_standard_t(double df, size):
873+
"""
874+
Returns an array populated with samples from standard t distribution.
875+
`dpnp_standard_t` generates a matrix filled with random floats sampled from a
876+
univariate standard t distribution for a given number of degrees of freedom.
877+
878+
"""
879+
880+
# convert string type names (dparray.dtype) to C enum DPNPFuncType
881+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(numpy.float64)
882+
883+
# get the FPTR data structure
884+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_RNG_STANDARD_T, param1_type, param1_type)
885+
886+
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
887+
# ceate result array with type given by FPTR data
888+
cdef dparray result = dparray(size, dtype=result_type)
889+
890+
cdef fptr_dpnp_rng_standard_t_c_1out_t func = <fptr_dpnp_rng_standard_t_c_1out_t > kernel_data.ptr
891+
# call FPTR function
892+
func(result.get_data(), df, result.size)
893+
894+
return result
895+
870896

871897
cpdef dparray dpnp_uniform(long low, long high, size, dtype):
872898
"""

dpnp/random/dpnp_iface_random.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,20 +1358,38 @@ def standard_normal(size=None):
13581358

13591359

13601360
def standard_t(df, size=None):
1361-
"""Power distribution.
1361+
"""Standard Student’s t distribution.
13621362
1363-
Draw samples from a standard Student’s t distribution with df degrees
1364-
of freedom.
1363+
Draw samples from a standard Student’s t distribution with
1364+
df degrees of freedom.
13651365
13661366
For full documentation refer to :obj:`numpy.random.standard_t`.
13671367
1368-
Notes
1369-
-----
1370-
The function uses `numpy.random.standard_t` on the backend and
1371-
will be executed on fallback backend.
1368+
Limitations
1369+
-----------
1370+
Parameter ``df`` is supported as a scalar.
1371+
Otherwise, :obj:`numpy.random.standard_t(df, size)` samples
1372+
are drawn.
1373+
Output array data type is :obj:`dpnp.float64`.
1374+
1375+
Examples
1376+
--------
1377+
Draw samples from the distribution:
1378+
>>> df = 2.
1379+
>>> s = dpnp.random.standard_t(df, 1000000)
13721380
13731381
"""
13741382

1383+
if not use_origin_backend(df) and dpnp_queue_is_cpu():
1384+
# TODO:
1385+
# array_like of floats for `df`
1386+
if not dpnp.isscalar(df):
1387+
pass
1388+
elif df <= 0:
1389+
pass
1390+
else:
1391+
return dpnp_standard_t(df, size)
1392+
print("here")
13751393
return call_origin(numpy.random.standard_t, df, size)
13761394

13771395

tests/test_random.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,23 @@ def test_seed(self):
671671
self.check_seed('standard_normal', {})
672672

673673

674+
class TestDistributionsStandardT(TestDistribution):
675+
676+
def test_moments(self):
677+
df = 300.0
678+
expected_mean = 0.0
679+
expected_var = df / (df - 2)
680+
self.check_moments('standard_t', expected_mean,
681+
expected_var, {'df': df})
682+
683+
def test_invalid_args(self):
684+
df = 0.0 # positive `df` is expected
685+
self.check_invalid_args('standard_t', {'df': df})
686+
687+
def test_seed(self):
688+
self.check_seed('standard_t', {'df': 10.0})
689+
690+
674691
class TestDistributionsUniform(TestDistribution):
675692

676693
def test_extreme_value(self):

tests_external/skipped_tests_numpy.tbl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,6 +1501,7 @@ tests/test_random.py::TestRandomDist::test_standard_exponential
15011501
tests/test_random.py::TestRandomDist::test_standard_gamma
15021502
tests/test_random.py::TestRandomDist::test_standard_gamma_0
15031503
tests/test_random.py::TestRandomDist::test_standard_normal
1504+
tests/test_random.py::TestRandomDist::test_standard_t
15041505
tests/test_random.py::TestRandomDist::test_uniform
15051506
tests/test_random.py::TestRandomDist::test_uniform_range_bounds
15061507
tests/test_random.py::TestRandomDist::test_weibull
@@ -1595,6 +1596,7 @@ tests/test_randomstate.py::TestRandomDist::test_standard_exponential
15951596
tests/test_randomstate.py::TestRandomDist::test_standard_gamma
15961597
tests/test_randomstate.py::TestRandomDist::test_standard_gamma_0
15971598
tests/test_randomstate.py::TestRandomDist::test_standard_normal
1599+
tests/test_randomstate.py::TestRandomDist::test_standard_t
15981600
tests/test_randomstate.py::TestRandomDist::test_tomaxint
15991601
tests/test_randomstate.py::TestRandomDist::test_uniform
16001602
tests/test_randomstate.py::TestRandomDist::test_uniform_range_bounds

0 commit comments

Comments
 (0)