Skip to content

Commit 9ea2351

Browse files
ENH: misc update; random.weibull (#724)
1 parent d98abe4 commit 9ea2351

File tree

2 files changed

+23
-24
lines changed

2 files changed

+23
-24
lines changed

dpnp/backend/kernels/dpnp_krnl_random.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,18 +1505,21 @@ void dpnp_rng_weibull_c(void* result, const double alpha, const size_t size)
15051505
{
15061506
return;
15071507
}
1508-
_DataType* result1 = reinterpret_cast<_DataType*>(result);
15091508

1510-
// set displacement a
1511-
const _DataType a = (_DataType(0.0));
1512-
1513-
// set beta
1514-
const _DataType beta = (_DataType(1.0));
1509+
if (alpha == 0)
1510+
{
1511+
dpnp_zeros_c<_DataType>(result, size);
1512+
}
1513+
else
1514+
{
1515+
_DataType* result1 = reinterpret_cast<_DataType*>(result);
1516+
const _DataType a = (_DataType(0.0));
1517+
const _DataType beta = (_DataType(1.0));
15151518

1516-
mkl_rng::weibull<_DataType> distribution(alpha, a, beta);
1517-
// perform generation
1518-
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
1519-
event_out.wait();
1519+
mkl_rng::weibull<_DataType> distribution(alpha, a, beta);
1520+
auto event_out = mkl_rng::generate(distribution, DPNP_RNG_ENGINE, size, result1);
1521+
event_out.wait();
1522+
}
15201523
}
15211524

15221525
template <typename _DataType>

dpnp/random/dpnp_algo_random.pyx

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,23 +1133,19 @@ cpdef dparray dpnp_rng_weibull(double a, size):
11331133
cdef DPNPFuncData kernel_data
11341134
cdef fptr_dpnp_rng_weibull_c_1out_t func
11351135

1136-
if a == 0.0:
1137-
result = dparray(size, dtype=dtype)
1138-
result.fill(0.0)
1139-
else:
1140-
# convert string type names (dparray.dtype) to C enum DPNPFuncType
1141-
param1_type = dpnp_dtype_to_DPNPFuncType(numpy.float64)
1136+
# convert string type names (dparray.dtype) to C enum DPNPFuncType
1137+
param1_type = dpnp_dtype_to_DPNPFuncType(numpy.float64)
11421138

1143-
# get the FPTR data structure
1144-
kernel_data = get_dpnp_function_ptr(DPNP_FN_RNG_WEIBULL, param1_type, param1_type)
1139+
# get the FPTR data structure
1140+
kernel_data = get_dpnp_function_ptr(DPNP_FN_RNG_WEIBULL, param1_type, param1_type)
11451141

1146-
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
1147-
# ceate result array with type given by FPTR data
1148-
result = dparray(size, dtype=result_type)
1142+
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
1143+
# ceate result array with type given by FPTR data
1144+
result = dparray(size, dtype=result_type)
11491145

1150-
func = <fptr_dpnp_rng_weibull_c_1out_t > kernel_data.ptr
1151-
# call FPTR function
1152-
func(result.get_data(), a, result.size)
1146+
func = <fptr_dpnp_rng_weibull_c_1out_t > kernel_data.ptr
1147+
# call FPTR function
1148+
func(result.get_data(), a, result.size)
11531149

11541150
return result
11551151

0 commit comments

Comments
 (0)