Skip to content

implement dpnp.piecewise #2550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Added `timeout-minutes` property to GitHub jobs [#2526](https://github.com/IntelPython/dpnp/pull/2526)
* Added implementation of `dpnp.ndarray.data` and `dpnp.ndarray.data.ptr` attributes [#2521](https://github.com/IntelPython/dpnp/pull/2521)
* Added `dpnp.ndarray.__contains__` method [#2534](https://github.com/IntelPython/dpnp/pull/2534)
* Added implementation of `dpnp.piecewise` [#2550](https://github.com/IntelPython/dpnp/pull/2550)

### Changed

Expand Down
144 changes: 143 additions & 1 deletion dpnp/dpnp_iface_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

"""

# pylint: disable=protected-access

from dpctl.tensor._numpy_helper import (
normalize_axis_index,
Expand All @@ -44,7 +45,10 @@

import dpnp

__all__ = ["apply_along_axis", "apply_over_axes"]
# pylint: disable=no-name-in-module
from dpnp.dpnp_utils import get_usm_allocations

__all__ = ["apply_along_axis", "apply_over_axes", "piecewise"]


def apply_along_axis(func1d, axis, arr, *args, **kwargs):
Expand Down Expand Up @@ -266,3 +270,141 @@ def apply_over_axes(func, a, axes):
)
a = res
return res


def piecewise(x, condlist, funclist):
"""
Evaluate a piecewise-defined function.

Given a set of conditions and corresponding functions, evaluate each
function on the input data wherever its condition is true.

For full documentation refer to :obj:`numpy.piecewise`.

Parameters
----------
x : {dpnp.ndarray, usm_ndarray}
The input domain.
condlist : {list of array-like boolean, bool scalars}
Each boolean array/scalar corresponds to a function in `funclist`.
Wherever `condlist[i]` is ``True``, `funclist[i](x)` is used as the
output value.

Each boolean array in `condlist` selects a piece of `x`, and should
therefore be of the same shape as `x`.

The length of `condlist` must correspond to that of `funclist`.
If one extra function is given, i.e. if
``len(funclist) == len(condlist) + 1``, then that extra function
is the default value, used wherever all conditions are ``False``.
funclist : {array-like of scalars}
A constant value is returned wherever corresponding condition of `x`
is ``True``.

Returns
-------
out : dpnp.ndarray
The output is the same shape and type as `x` and is found by
calling the functions in `funclist` on the appropriate portions of `x`,
as defined by the boolean arrays in `condlist`. Portions not covered
by any condition have a default value of ``0``.

Limitations
-----------
Parameters `args` and `kw` are not supported and `funclist` cannot include a
callable functions.

See Also
--------
:obj:`dpnp.choose` : Construct an array from an index array and a set of
arrays to choose from.
:obj:`dpnp.select` : Return an array drawn from elements in `choicelist`,
depending on conditions.
:obj:`dpnp.where` : Return elements from one of two arrays depending
on condition.

Examples
--------
>>> import dpnp as np

Define the signum function, which is -1 for ``x < 0`` and +1 for ``x >= 0``.

>>> x = np.linspace(-2.5, 2.5, 6)
>>> np.piecewise(x, [x < 0, x >= 0], [-1, 1])
array([-1., -1., -1., 1., 1., 1.])

"""
dpnp.check_supported_arrays_type(x)
x_dtype = x.dtype
if isinstance(condlist, dpnp.ndarray) and condlist.ndim in [0, 1]:
condlist = [condlist]
elif dpnp.isscalar(condlist) or (
dpnp.isscalar(condlist[0]) and x.ndim != 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no check that condlist is exactly a list. Do we need that or it is assuming to accept any sequence?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In practice, condlist can be a list, tuple or ndarray in both NumPy and CuPy. And dpnp follows the same practice.

import numpy
a = numpy.linspace(-2.5, 2.5, 6)
numpy.piecewise(a, [a < 0, a >= 0], [-1, 1])  # condlist is list
# array([-1., -1., -1.,  1.,  1.,  1.])

numpy.piecewise(a, (a < 0, a >= 0), [-1, 1])  # condlist is tuple
# array([-1., -1., -1.,  1.,  1.,  1.])

numpy.piecewise(a, numpy.array([a < 0, a >= 0]), [-1, 1])  # condlist is array
# array([-1., -1., -1.,  1.,  1.,  1.])

import cupy
x = cupy.array(a)

cupy.piecewise(x, [x < 0, x >= 0], [-1, 1])  # condlist is list
# array([-1., -1., -1.,  1.,  1.,  1.])

cupy.piecewise(x, (x < 0, x >= 0), [-1, 1])  # condlist is tuple
# array([-1., -1., -1.,  1.,  1.,  1.])

cupy.piecewise(x, cupy.array([x < 0, x >= 0]), [-1, 1])  # condlist is array
# array([-1., -1., -1.,  1.,  1.,  1.])

import dpnp
ia = dpnp.array(a)
dpnp.piecewise(ia, [ia < 0, ia >= 0], [-1, 1])  # condlist is list
# array([-1., -1., -1.,  1.,  1.,  1.])

dpnp.piecewise(ia, (ia < 0, ia >= 0), [-1, 1])  # condlist is tuple
# array([-1., -1., -1.,  1.,  1.,  1.])

dpnp.piecewise(ia, dpnp.array([ia < 0, ia >= 0]), [-1, 1])  # condlist is array
# array([-1., -1., -1.,  1.,  1.,  1.])

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to raise an error like ValueError if condlist is empty?

dpnp.piecewise(x, [], [1])
# output
IndexError: list index out of range

):
# convert scalar to a list of one array
# convert list of scalars to a list of one array
condlist = [
dpnp.full(
x.shape, condlist, usm_type=x.usm_type, sycl_queue=x.sycl_queue
)
]
elif not isinstance(condlist[0], (dpnp.ndarray)):
# convert list of lists to list of arrays
# convert list of scalars to a list of 0d arrays (for 0d input)
tmp = []
for _, cond in enumerate(condlist):
tmp.append(
dpnp.array(cond, usm_type=x.usm_type, sycl_queue=x.sycl_queue)
)
condlist = tmp

dpnp.check_supported_arrays_type(*condlist)
if dpnp.is_supported_array_type(funclist):
usm_type, exec_q = get_usm_allocations([x, *condlist, funclist])
else:
usm_type, exec_q = get_usm_allocations([x, *condlist])

result = dpnp.empty_like(x, usm_type=usm_type, sycl_queue=exec_q)

condlen = len(condlist)
try:
if isinstance(funclist, str):
raise TypeError
funclen = len(funclist)
except TypeError as e:
raise TypeError("funclist must be a sequence of scalars") from e
if condlen == funclen:
# default value is zero
default_value = x_dtype.type(0)
elif condlen + 1 == funclen:
# default value is the last element of funclist
default_value = funclist[-1]
if callable(default_value):
raise NotImplementedError(
"Callable functions are not supported currently"
)
if isinstance(default_value, dpnp.ndarray):
default_value = default_value.astype(x_dtype)
else:
default_value = x_dtype.type(default_value)
funclist = funclist[:-1]

else:
raise ValueError(
f"with {condlen} condition(s), either {condlen} or {condlen + 1} "
"functions are expected"
)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please handle this case as well

x = dpnp.ones(3)
dpnp.piecewise(x, [x>0], dpnp.array([[1,2,3]]))
# output 
RuntimeError: Unable to cast Python instance of type <class 'dpnp.dpnp_array.dpnp_array'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)

for condition, func in zip(condlist, funclist):
if callable(func):
raise NotImplementedError(
"Callable functions are not supported currently"
)
if isinstance(func, dpnp.ndarray):
func = func.astype(x_dtype)
else:
func = x_dtype.type(func)
dpnp.where(condition, func, default_value, out=result)
default_value = result

return result
Loading
Loading