From 3c8e6b0d545c35beb88f4fb84e7b5064c02b305f Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Wed, 20 Nov 2024 14:23:01 -0800 Subject: [PATCH 1/2] simplify dpnp_wrap_reduction_call implementation --- dpnp/dpnp_iface_mathematical.py | 32 ++++++++++--------------- dpnp/dpnp_iface_searching.py | 14 +++++------ dpnp/dpnp_iface_statistics.py | 16 ++++--------- dpnp/dpnp_iface_trigonometric.py | 17 ++++++------- dpnp/dpnp_utils/dpnp_utils_reduction.py | 10 ++------ 5 files changed, 32 insertions(+), 57 deletions(-) diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 073de3d8996a..86e0d760d016 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -152,7 +152,7 @@ def _get_max_min(dtype): return f.max, f.min -def _get_reduction_res_dt(a, dtype, _out): +def _get_reduction_res_dt(a, dtype): """Get a data type used by dpctl for result array in reduction function.""" if dtype is None: @@ -1106,11 +1106,10 @@ def cumprod(a, axis=None, dtype=None, out=None): usm_a = dpnp.get_usm_ndarray(a) return dpnp_wrap_reduction_call( - a, + usm_a, out, dpt.cumulative_prod, - _get_reduction_res_dt, - usm_a, + _get_reduction_res_dt(a, dtype), axis=axis, dtype=dtype, ) @@ -1196,11 +1195,10 @@ def cumsum(a, axis=None, dtype=None, out=None): usm_a = dpnp.get_usm_ndarray(a) return dpnp_wrap_reduction_call( - a, + usm_a, out, dpt.cumulative_sum, - _get_reduction_res_dt, - usm_a, + _get_reduction_res_dt(a, dtype), axis=axis, dtype=dtype, ) @@ -1281,11 +1279,10 @@ def cumulative_prod( """ return dpnp_wrap_reduction_call( - x, + dpnp.get_usm_ndarray(x), out, dpt.cumulative_prod, - _get_reduction_res_dt, - dpnp.get_usm_ndarray(x), + _get_reduction_res_dt(x, dtype), axis=axis, dtype=dtype, include_initial=include_initial, @@ -1373,11 +1370,10 @@ def cumulative_sum( """ return dpnp_wrap_reduction_call( - x, + dpnp.get_usm_ndarray(x), out, dpt.cumulative_sum, - _get_reduction_res_dt, - dpnp.get_usm_ndarray(x), + _get_reduction_res_dt(x, dtype), axis=axis, dtype=dtype, include_initial=include_initial, @@ -3524,11 +3520,10 @@ def prod( usm_a = dpnp.get_usm_ndarray(a) return dpnp_wrap_reduction_call( - a, + usm_a, out, dpt.prod, - _get_reduction_res_dt, - usm_a, + _get_reduction_res_dt(a, dtype), axis=axis, dtype=dtype, keepdims=keepdims, @@ -4297,11 +4292,10 @@ def sum( usm_a = dpnp.get_usm_ndarray(a) return dpnp_wrap_reduction_call( - a, + usm_a, out, dpt.sum, - _get_reduction_res_dt, - usm_a, + _get_reduction_res_dt(a, dtype), axis=axis, dtype=dtype, keepdims=keepdims, diff --git a/dpnp/dpnp_iface_searching.py b/dpnp/dpnp_iface_searching.py index 192483423765..1b1f7354c462 100644 --- a/dpnp/dpnp_iface_searching.py +++ b/dpnp/dpnp_iface_searching.py @@ -48,14 +48,14 @@ __all__ = ["argmax", "argmin", "argwhere", "searchsorted", "where"] -def _get_search_res_dt(a, _dtype, out): +def _get_search_res_dt(a, out): """Get a data type used by dpctl for result array in search function.""" # get a data type used by dpctl for result array in search function res_dt = dti.default_device_index_type(a.sycl_device) # numpy raises TypeError if "out" data type mismatch default index type - if not dpnp.can_cast(out.dtype, res_dt, casting="safe"): + if out is not None and not dpnp.can_cast(out.dtype, res_dt, casting="safe"): raise TypeError( f"Cannot cast from {out.dtype} to {res_dt} " "according to the rule safe." @@ -143,11 +143,10 @@ def argmax(a, axis=None, out=None, *, keepdims=False): usm_a = dpnp.get_usm_ndarray(a) return dpnp_wrap_reduction_call( - a, + usm_a, out, dpt.argmax, - _get_search_res_dt, - usm_a, + _get_search_res_dt(a, out), axis=axis, keepdims=keepdims, ) @@ -234,11 +233,10 @@ def argmin(a, axis=None, out=None, *, keepdims=False): usm_a = dpnp.get_usm_ndarray(a) return dpnp_wrap_reduction_call( - a, + usm_a, out, dpt.argmin, - _get_search_res_dt, - usm_a, + _get_search_res_dt(a, out), axis=axis, keepdims=keepdims, ) diff --git a/dpnp/dpnp_iface_statistics.py b/dpnp/dpnp_iface_statistics.py index 76d6c16defd1..0142f07ac31e 100644 --- a/dpnp/dpnp_iface_statistics.py +++ b/dpnp/dpnp_iface_statistics.py @@ -115,12 +115,6 @@ def _count_reduce_items(arr, axis, where=True): return items -def _get_comparison_res_dt(a, _dtype, _out): - """Get a data type used by dpctl for result array in comparison function.""" - - return a.dtype - - def amax(a, axis=None, out=None, keepdims=False, initial=None, where=True): """ Return the maximum of an array or maximum along an axis. @@ -760,11 +754,10 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True): usm_a = dpnp.get_usm_ndarray(a) return dpnp_wrap_reduction_call( - a, + usm_a, out, dpt.max, - _get_comparison_res_dt, - usm_a, + a.dtype, axis=axis, keepdims=keepdims, ) @@ -1026,11 +1019,10 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True): usm_a = dpnp.get_usm_ndarray(a) return dpnp_wrap_reduction_call( - a, + usm_a, out, dpt.min, - _get_comparison_res_dt, - usm_a, + a.dtype, axis=axis, keepdims=keepdims, ) diff --git a/dpnp/dpnp_iface_trigonometric.py b/dpnp/dpnp_iface_trigonometric.py index 14a2e4beec61..5acaa60c2ad9 100644 --- a/dpnp/dpnp_iface_trigonometric.py +++ b/dpnp/dpnp_iface_trigonometric.py @@ -98,7 +98,7 @@ ] -def _get_accumulation_res_dt(a, dtype, _out): +def _get_accumulation_res_dt(a, dtype): """Get a dtype used by dpctl for result array in accumulation function.""" if dtype is None: @@ -893,11 +893,10 @@ def cumlogsumexp( usm_x = dpnp.get_usm_ndarray(x) return dpnp_wrap_reduction_call( - x, + usm_x, out, dpt.cumulative_logsumexp, - _get_accumulation_res_dt, - usm_x, + _get_accumulation_res_dt(x, dtype), axis=axis, dtype=dtype, include_initial=include_initial, @@ -1705,11 +1704,10 @@ def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False, out=None): usm_x = dpnp.get_usm_ndarray(x) return dpnp_wrap_reduction_call( - x, + usm_x, out, dpt.logsumexp, - _get_accumulation_res_dt, - usm_x, + _get_accumulation_res_dt(x, dtype), axis=axis, dtype=dtype, keepdims=keepdims, @@ -1952,11 +1950,10 @@ def reduce_hypot(x, /, *, axis=None, dtype=None, keepdims=False, out=None): usm_x = dpnp.get_usm_ndarray(x) return dpnp_wrap_reduction_call( - x, + usm_x, out, dpt.reduce_hypot, - _get_accumulation_res_dt, - usm_x, + _get_accumulation_res_dt(x, dtype), axis=axis, dtype=dtype, keepdims=keepdims, diff --git a/dpnp/dpnp_utils/dpnp_utils_reduction.py b/dpnp/dpnp_utils/dpnp_utils_reduction.py index 5565c8f13748..8051bcc40c09 100644 --- a/dpnp/dpnp_utils/dpnp_utils_reduction.py +++ b/dpnp/dpnp_utils/dpnp_utils_reduction.py @@ -29,9 +29,7 @@ __all__ = ["dpnp_wrap_reduction_call"] -def dpnp_wrap_reduction_call( - a, out, _reduction_fn, _get_res_dt_fn, *args, **kwargs -): +def dpnp_wrap_reduction_call(a, out, _reduction_fn, res_dt, **kwargs): """Wrap a reduction call from dpctl.tensor interface.""" input_out = out @@ -40,16 +38,12 @@ def dpnp_wrap_reduction_call( else: dpnp.check_supported_arrays_type(out) - # fetch dtype from the passed kwargs to the reduction call - dtype = kwargs.get("dtype", None) - # dpctl requires strict data type matching of out array with the result - res_dt = _get_res_dt_fn(a, dtype, out) if out.dtype != res_dt: out = dpnp.astype(out, dtype=res_dt, copy=False) usm_out = dpnp.get_usm_ndarray(out) kwargs["out"] = usm_out - res_usm = _reduction_fn(*args, **kwargs) + res_usm = _reduction_fn(a, **kwargs) return dpnp.get_result_array(res_usm, input_out, casting="unsafe") From 54f006f3606f6b380e2ea97c39bfee7b68a061dd Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Sat, 14 Dec 2024 10:01:15 -0800 Subject: [PATCH 2/2] address comments --- dpnp/dpnp_utils/dpnp_utils_reduction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dpnp/dpnp_utils/dpnp_utils_reduction.py b/dpnp/dpnp_utils/dpnp_utils_reduction.py index 8051bcc40c09..1fae00a4da51 100644 --- a/dpnp/dpnp_utils/dpnp_utils_reduction.py +++ b/dpnp/dpnp_utils/dpnp_utils_reduction.py @@ -29,7 +29,7 @@ __all__ = ["dpnp_wrap_reduction_call"] -def dpnp_wrap_reduction_call(a, out, _reduction_fn, res_dt, **kwargs): +def dpnp_wrap_reduction_call(usm_a, out, _reduction_fn, res_dt, **kwargs): """Wrap a reduction call from dpctl.tensor interface.""" input_out = out @@ -45,5 +45,5 @@ def dpnp_wrap_reduction_call(a, out, _reduction_fn, res_dt, **kwargs): usm_out = dpnp.get_usm_ndarray(out) kwargs["out"] = usm_out - res_usm = _reduction_fn(a, **kwargs) + res_usm = _reduction_fn(usm_a, **kwargs) return dpnp.get_result_array(res_usm, input_out, casting="unsafe")