Skip to content

Commit 054d54c

Browse files
authored
dpnp.cov to desc (#847)
1 parent 3d28cca commit 054d54c

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ cpdef dparray dpnp_transpose(dpnp_descriptor array1, axes=*)
346346
"""
347347
Statistics functions
348348
"""
349-
cpdef dparray dpnp_cov(dparray array1)
349+
cpdef dpnp_descriptor dpnp_cov(dpnp_descriptor array1)
350350
cpdef dparray dpnp_mean(dparray a, axis)
351351
cpdef dparray dpnp_min(dparray a, axis)
352352

dpnp/dpnp_algo/dpnp_algo_statistics.pyx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ cpdef dparray dpnp_correlate(utils.dpnp_descriptor x1, utils.dpnp_descriptor x2)
113113
return result
114114

115115

116-
cpdef dparray dpnp_cov(dparray array1):
116+
# supports "double" input only
117+
cpdef utils.dpnp_descriptor dpnp_cov(utils.dpnp_descriptor array1):
117118
cdef shape_type_c input_shape = array1.shape
118119

119120
if array1.ndim == 1:
@@ -125,14 +126,13 @@ cpdef dparray dpnp_cov(dparray array1):
125126
# get the FPTR data structure
126127
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_COV, param1_type, param1_type)
127128

128-
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
129129
# ceate result array with type given by FPTR data
130-
in_array = array1.astype(result_type)
131-
cdef dparray result = dparray((input_shape[0], input_shape[0]), dtype=result_type)
130+
cdef shape_type_c result_shape = (input_shape[0], input_shape[0])
131+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
132132

133133
cdef fptr_custom_cov_1in_1out_t func = <fptr_custom_cov_1in_1out_t > kernel_data.ptr
134134
# call FPTR function
135-
func(in_array.get_data(), result.get_data(), input_shape[0], input_shape[1])
135+
func(array1.get_data(), result.get_data(), input_shape[0], input_shape[1])
136136

137137
return result
138138

dpnp/dpnp_iface_statistics.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,11 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=
291291
elif aweights is not None:
292292
pass
293293
else:
294-
return dpnp_cov(x1)
294+
if x1_desc.dtype != dpnp.float64:
295+
x1_double_container = x1.astype(dpnp.float64)
296+
x1_desc = dpnp.get_dpnp_descriptor(x1_double_container)
297+
298+
return dpnp_cov(x1_desc).get_pyobj()
295299

296300
return call_origin(numpy.cov, x1, y, rowvar, bias, ddof, fweights, aweights)
297301

0 commit comments

Comments
 (0)