Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Added `timeout-minutes` property to GitHub jobs [#2526](https://github.com/IntelPython/dpnp/pull/2526)
* Added implementation of `dpnp.ndarray.data` and `dpnp.ndarray.data.ptr` attributes [#2521](https://github.com/IntelPython/dpnp/pull/2521)
* Added `dpnp.ndarray.__contains__` method [#2534](https://github.com/IntelPython/dpnp/pull/2534)
* Added implementation of `dpnp.piecewise` [#2550](https://github.com/IntelPython/dpnp/pull/2550)

### Changed

Expand Down
144 changes: 143 additions & 1 deletion dpnp/dpnp_iface_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

"""

# pylint: disable=protected-access

from dpctl.tensor._numpy_helper import (
normalize_axis_index,
Expand All @@ -44,7 +45,10 @@

import dpnp

__all__ = ["apply_along_axis", "apply_over_axes"]
# pylint: disable=no-name-in-module
from dpnp.dpnp_utils import get_usm_allocations

__all__ = ["apply_along_axis", "apply_over_axes", "piecewise"]


def apply_along_axis(func1d, axis, arr, *args, **kwargs):
Expand Down Expand Up @@ -266,3 +270,141 @@ def apply_over_axes(func, a, axes):
)
a = res
return res


def piecewise(x, condlist, funclist):
"""
Evaluate a piecewise-defined function.

Given a set of conditions and corresponding functions, evaluate each
function on the input data wherever its condition is true.

For full documentation refer to :obj:`numpy.piecewise`.

Parameters
----------
x : {dpnp.ndarray, usm_ndarray}
The input domain.
condlist : {list of array-like boolean, bool scalars}
Each boolean array/scalar corresponds to a function in `funclist`.
Wherever `condlist[i]` is ``True``, `funclist[i](x)` is used as the
output value.

Each boolean array in `condlist` selects a piece of `x`, and should
therefore be of the same shape as `x`.

The length of `condlist` must correspond to that of `funclist`.
If one extra function is given, i.e. if
``len(funclist) == len(condlist) + 1``, then that extra function
is the default value, used wherever all conditions are ``False``.
funclist : {array-like of scalars}
A constant value is returned wherever corresponding condition of `x`
is ``True``.

Returns
-------
out : dpnp.ndarray
The output is the same shape and type as `x` and is found by
calling the functions in `funclist` on the appropriate portions of `x`,
as defined by the boolean arrays in `condlist`. Portions not covered
by any condition have a default value of ``0``.

Limitations
-----------
Parameters `args` and `kw` are not supported and `funclist` cannot include a
callable functions.

See Also
--------
:obj:`dpnp.choose` : Construct an array from an index array and a set of
arrays to choose from.
:obj:`dpnp.select` : Return an array drawn from elements in `choicelist`,
depending on conditions.
:obj:`dpnp.where` : Return elements from one of two arrays depending
on condition.

Examples
--------
>>> import dpnp as np

Define the signum function, which is -1 for ``x < 0`` and +1 for ``x >= 0``.

>>> x = np.linspace(-2.5, 2.5, 6)
>>> np.piecewise(x, [x < 0, x >= 0], [-1, 1])
array([-1., -1., -1., 1., 1., 1.])

"""
dpnp.check_supported_arrays_type(x)
x_dtype = x.dtype
if isinstance(condlist, dpnp.ndarray) and condlist.ndim in [0, 1]:
condlist = [condlist]
elif dpnp.isscalar(condlist) or (
dpnp.isscalar(condlist[0]) and x.ndim != 0
):
# convert scalar to a list of one array
# convert list of scalars to a list of one array
condlist = [
dpnp.full(
x.shape, condlist, usm_type=x.usm_type, sycl_queue=x.sycl_queue
)
]
elif not isinstance(condlist[0], (dpnp.ndarray)):
# convert list of lists to list of arrays
# convert list of scalars to a list of 0d arrays (for 0d input)
tmp = []
for _, cond in enumerate(condlist):
tmp.append(
dpnp.array(cond, usm_type=x.usm_type, sycl_queue=x.sycl_queue)
)
condlist = tmp

dpnp.check_supported_arrays_type(*condlist)
if dpnp.is_supported_array_type(funclist):
usm_type, exec_q = get_usm_allocations([x, *condlist, funclist])
else:
usm_type, exec_q = get_usm_allocations([x, *condlist])

result = dpnp.empty_like(x, usm_type=usm_type, sycl_queue=exec_q)

condlen = len(condlist)
try:
if isinstance(funclist, str):
raise TypeError
funclen = len(funclist)
except TypeError as e:
raise TypeError("funclist must be a sequence of scalars") from e
if condlen == funclen:
# default value is zero
default_value = x_dtype.type(0)
elif condlen + 1 == funclen:
# default value is the last element of funclist
default_value = funclist[-1]
if callable(default_value):
raise NotImplementedError(
"Callable functions are not supported currently"
)
if isinstance(default_value, dpnp.ndarray):
default_value = default_value.astype(x_dtype)
else:
default_value = x_dtype.type(default_value)
funclist = funclist[:-1]

else:
raise ValueError(
f"with {condlen} condition(s), either {condlen} or {condlen + 1} "
"functions are expected"
)

for condition, func in zip(condlist, funclist):
if callable(func):
raise NotImplementedError(
"Callable functions are not supported currently"
)
if isinstance(func, dpnp.ndarray):
func = func.astype(x_dtype)
else:
func = x_dtype.type(func)
dpnp.where(condition, func, default_value, out=result)
default_value = result

return result
Loading
Loading