Skip to content

Commit 1fa62ca

Browse files
ENH: fptr for random.seed backend (#516)
* ENH: fptr for random.seed backend * rename DPNP_FN_SRAND to DPNP_FN_RNG_SRAND
1 parent 50762cd commit 1fa62ca

File tree

5 files changed

+26
-1
lines changed

5 files changed

+26
-1
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ enum class DPNPFuncName : size_t
142142
DPNP_FN_RNG_POISSON, /**< Used in numpy.random.poisson() implementation */
143143
DPNP_FN_RNG_POWER, /**< Used in numpy.random.power() implementation */
144144
DPNP_FN_RNG_RAYLEIGH, /**< Used in numpy.random.rayleigh() implementation */
145+
DPNP_FN_RNG_SRAND, /**< Used in numpy.random.seed() implementation */
145146
DPNP_FN_RNG_STANDARD_CAUCHY, /**< Used in numpy.random.standard_cauchy() implementation */
146147
DPNP_FN_RNG_STANDARD_EXPONENTIAL, /**< Used in numpy.random.standard_exponential() implementation */
147148
DPNP_FN_RNG_STANDARD_GAMMA, /**< Used in numpy.random.standard_gamma() implementation */

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,7 @@ void func_map_init_random(func_map_t& fmap)
786786
fmap[DPNPFuncName::DPNP_FN_RNG_UNIFORM][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_rng_uniform_c<int>};
787787

788788
fmap[DPNPFuncName::DPNP_FN_RNG_WEIBULL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_rng_weibull_c<double>};
789+
fmap[DPNPFuncName::DPNP_FN_RNG_SRAND][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_srand_c};
789790

790791
return;
791792
}

dpnp/backend/tests/test_random.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <dpnp_iface.hpp>
2+
#include <dpnp_iface_fptr.hpp>
23

34
#include <vector>
45

@@ -85,6 +86,18 @@ TEST (TestBackendRandomUniform, test_seed) {
8586
}
8687
}
8788

89+
TEST (TestBackendRandomSrand, test_func_ptr) {
90+
91+
void * fptr = nullptr;
92+
DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNPFuncName::DPNP_FN_RNG_SRAND,
93+
DPNPFuncType::DPNP_FT_DOUBLE, DPNPFuncType::DPNP_FT_DOUBLE);
94+
95+
fptr = get_dpnp_function_ptr1(kernel_data.return_type, DPNPFuncName::DPNP_FN_RNG_SRAND,
96+
DPNPFuncType::DPNP_FT_DOUBLE, DPNPFuncType::DPNP_FT_DOUBLE);
97+
98+
EXPECT_TRUE(fptr != nullptr);
99+
}
100+
88101
int main(int argc, char **argv) {
89102
::testing::InitGoogleTest(&argc, argv);
90103
return RUN_ALL_TESTS();

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
115115
DPNP_FN_RNG_POISSON
116116
DPNP_FN_RNG_POWER
117117
DPNP_FN_RNG_RAYLEIGH
118+
DPNP_FN_RNG_SRAND
118119
DPNP_FN_RNG_STANDARD_CAUCHY
119120
DPNP_FN_RNG_STANDARD_EXPONENTIAL
120121
DPNP_FN_RNG_STANDARD_GAMMA

dpnp/random/dpnp_algo_random.pyx

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ ctypedef void(*fptr_dpnp_rng_standard_normal_c_1out_t)(void *, size_t) except +
108108
ctypedef void(*fptr_dpnp_rng_standard_t_c_1out_t)(void *, double, size_t) except +
109109
ctypedef void(*fptr_dpnp_rng_uniform_c_1out_t)(void *, long, long, size_t) except +
110110
ctypedef void(*fptr_dpnp_rng_weibull_c_1out_t)(void *, double, size_t) except +
111+
ctypedef void(*fptr_dpnp_srand_c_1out_t)(size_t) except +
111112

112113

113114
cpdef dparray dpnp_beta(double a, double b, size):
@@ -810,7 +811,15 @@ cpdef dpnp_srand(seed):
810811
811812
"""
812813

813-
dpnp_srand_c(seed)
814+
# convert string type names (dparray.dtype) to C enum DPNPFuncType
815+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(numpy.float64)
816+
817+
# get the FPTR data structure
818+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_RNG_SRAND, param1_type, param1_type)
819+
820+
cdef fptr_dpnp_srand_c_1out_t func = < fptr_dpnp_srand_c_1out_t > kernel_data.ptr
821+
# call FPTR function
822+
func(seed)
814823

815824

816825
cpdef dparray dpnp_standard_cauchy(size):

0 commit comments

Comments
 (0)