Skip to content

Commit 4eb0947

Browse files
authored
move multiply() to descriptor usage (#740)
1 parent 77b8270 commit 4eb0947

File tree

7 files changed

+46
-11
lines changed

7 files changed

+46
-11
lines changed

dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,12 @@ MACRO_2ARG_3TYPES_OP(dpnp_fmod_c, cl::sycl::fmod((double)input1_elem, (double)in
6666
MACRO_2ARG_3TYPES_OP(dpnp_hypot_c, cl::sycl::hypot((double)input1_elem, (double)input2_elem), oneapi::mkl::vm::hypot)
6767
MACRO_2ARG_3TYPES_OP(dpnp_maximum_c, cl::sycl::max(input1_elem, input2_elem), oneapi::mkl::vm::fmax)
6868
MACRO_2ARG_3TYPES_OP(dpnp_minimum_c, cl::sycl::min(input1_elem, input2_elem), oneapi::mkl::vm::fmin)
69+
70+
// "multiply" needs to be standalone kernel (not autogenerated) due to complex algorithm. This is not an element wise.
71+
// pytest "tests/third_party/cupy/creation_tests/test_ranges.py::TestMgrid::test_mgrid3"
72+
// requires multiplication shape1[10] with shape2[10,1] and result expected as shape[10,10]
6973
MACRO_2ARG_3TYPES_OP(dpnp_multiply_c, input1_elem* input2_elem, oneapi::mkl::vm::mul)
74+
7075
MACRO_2ARG_3TYPES_OP(dpnp_power_c, cl::sycl::pow((double)input1_elem, (double)input2_elem), oneapi::mkl::vm::pow)
7176
MACRO_2ARG_3TYPES_OP(dpnp_subtract_c, input1_elem - input2_elem, oneapi::mkl::vm::sub)
7277

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,3 +409,26 @@ cdef dparray call_fptr_2in_1out(DPNPFuncName fptr_name, object x1_obj, object x2
409409
x2_dparray.get_data(), x2_dparray.size, x2_shape.data(), x2_shape.size(), NULL)
410410

411411
return result
412+
413+
# this is replacement for "call_fptr_2in_1out". original function must be deleted after transotion.
414+
cdef dparray call_fptr_2in_1out_new(DPNPFuncName fptr_name, dpnp_descriptor x1_obj, dpnp_descriptor x2_obj,
415+
object dtype=None, dparray out=None, object where=True):
416+
# Convert string type names (dparray.dtype) to C enum DPNPFuncType
417+
cdef DPNPFuncType x1_c_type = dpnp_dtype_to_DPNPFuncType(x1_obj.dtype)
418+
cdef DPNPFuncType x2_c_type = dpnp_dtype_to_DPNPFuncType(x2_obj.dtype)
419+
420+
# get the FPTR data structure
421+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, x1_c_type, x2_c_type)
422+
423+
# Create result array
424+
cdef dparray_shape_type x1_shape = x1_obj.shape
425+
cdef dparray_shape_type x2_shape = x2_obj.shape
426+
cdef dparray_shape_type result_shape = get_common_shape(x1_shape, x2_shape)
427+
cdef dparray result = create_output_array(result_shape, kernel_data.return_type, out)
428+
429+
""" Call FPTR function """
430+
cdef fptr_2in_1out_t func = <fptr_2in_1out_t > kernel_data.ptr
431+
func(result.get_data(), x1_obj.get_data(), x1_obj.size, x1_shape.data(), x1_shape.size(),
432+
x2_obj.get_data(), x2_obj.size, x2_shape.data(), x2_shape.size(), NULL)
433+
434+
return result

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ cpdef dparray dpnp_dot(dparray in_array1, dparray in_array2):
6464

6565
# scalar
6666
if dim1 == 0 or dim2 == 0:
67-
return dpnp_multiply(in_array1, in_array2)
67+
x1_desc = dpnp.get_dpnp_descriptor(in_array1)
68+
x2_desc = dpnp.get_dpnp_descriptor(in_array2)
69+
return dpnp_multiply(x1_desc, x2_desc)
6870

6971
cdef size_t size1 = 0
7072
cdef size_t size2 = 0

dpnp/dpnp_algo/dpnp_algo_mathematical.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ cpdef tuple dpnp_modf(dpnp_descriptor x1):
277277

278278

279279
cpdef dparray dpnp_multiply(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
280-
return call_fptr_2in_1out(DPNP_FN_MULTIPLY, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
280+
return call_fptr_2in_1out_new(DPNP_FN_MULTIPLY, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
281281

282282

283283
cpdef dparray dpnp_nancumprod(dparray x1):

dpnp/dpnp_iface_mathematical.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,19 +1071,20 @@ def multiply(x1, x2, dtype=None, out=None, where=True, **kwargs):
10711071
[1, 4, 9, 16, 25]
10721072
10731073
"""
1074-
x1_is_scalar, x2_is_scalar = dpnp.isscalar(x1), dpnp.isscalar(x2)
1075-
x1_is_dparray, x2_is_dparray = isinstance(x1, dparray), isinstance(x2, dparray)
10761074

1077-
if not use_origin_backend(x1) and not kwargs:
1078-
if not x1_is_dparray and not x1_is_scalar:
1079-
pass
1080-
elif not x2_is_dparray and not x2_is_scalar:
1075+
x1_is_scalar = dpnp.isscalar(x1)
1076+
x2_is_scalar = dpnp.isscalar(x2)
1077+
x1_desc = dpnp.get_dpnp_descriptor(x1)
1078+
x2_desc = dpnp.get_dpnp_descriptor(x2)
1079+
1080+
if x1_desc and x2_desc and not kwargs:
1081+
if not x2_desc and not x2_is_scalar:
10811082
pass
10821083
elif x1_is_scalar and x2_is_scalar:
10831084
pass
1084-
elif x1_is_dparray and x1.ndim == 0:
1085+
elif x1_desc and x1_desc.ndim == 0:
10851086
pass
1086-
elif x2_is_dparray and x2.ndim == 0:
1087+
elif x2_desc and x2_desc.ndim == 0:
10871088
pass
10881089
elif dtype is not None:
10891090
pass
@@ -1092,7 +1093,7 @@ def multiply(x1, x2, dtype=None, out=None, where=True, **kwargs):
10921093
elif not where:
10931094
pass
10941095
else:
1095-
return dpnp_multiply(x1, x2, dtype=dtype, out=out, where=where)
1096+
return dpnp_multiply(x1_desc, x2_desc, dtype=dtype, out=out, where=where)
10961097

10971098
return call_origin(numpy.multiply, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
10981099

tests/skipped_tests.tbl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ tests/third_party/cupy/creation_tests/test_ranges.py::TestMeshgrid_param_7_{copy
241241
tests/third_party/cupy/creation_tests/test_ranges.py::TestMeshgrid_param_7_{copy=True, indexing='ij', sparse=True}::test_meshgrid1
242242
tests/third_party/cupy/creation_tests/test_ranges.py::TestMeshgrid_param_7_{copy=True, indexing='ij', sparse=True}::test_meshgrid2
243243
tests/third_party/cupy/creation_tests/test_ranges.py::TestMeshgrid_param_7_{copy=True, indexing='ij', sparse=True}::test_meshgrid3
244+
tests/third_party/cupy/creation_tests/test_ranges.py::TestMgrid::test_mgrid3
245+
tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid3
244246
tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid4
245247
tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid5
246248
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_array_start_stop

tests/skipped_tests_gpu.tbl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,8 @@ tests/third_party/cupy/creation_tests/test_ranges.py::TestMeshgrid_param_7_{copy
252252
tests/third_party/cupy/creation_tests/test_ranges.py::TestMeshgrid_param_7_{copy=True, indexing='ij', sparse=True}::test_meshgrid1
253253
tests/third_party/cupy/creation_tests/test_ranges.py::TestMeshgrid_param_7_{copy=True, indexing='ij', sparse=True}::test_meshgrid2
254254
tests/third_party/cupy/creation_tests/test_ranges.py::TestMeshgrid_param_7_{copy=True, indexing='ij', sparse=True}::test_meshgrid3
255+
tests/third_party/cupy/creation_tests/test_ranges.py::TestMgrid::test_mgrid3
256+
tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid3
255257
tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid4
256258
tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid5
257259
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_array_start_stop

0 commit comments

Comments
 (0)