Skip to content

Commit 6878157

Browse files
committed
xp.where(..., copy=None)
1 parent ee25aae commit 6878157

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

array_api_compat/common/_helpers.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,54 @@ def size(x):
801801
return None
802802
return math.prod(x.shape)
803803

804+
805+
def is_writeable_array(x):
806+
"""
807+
Return False if x.__setitem__ is expected to raise; True otherwise
808+
"""
809+
if is_numpy_array(x):
810+
return x.flags.writeable
811+
if is_jax_array(x) or is_pydata_sparse_array(x):
812+
return False
813+
return True
814+
815+
816+
def where(condition, x=None, y=None, /, copy: bool | None = True):
817+
"""Return elements from x when condition is True and from y when
818+
it is False.
819+
820+
This is a wrapper around xp.where that adds the copy parameter:
821+
822+
None
823+
x *may* be modified in place if it is possible and beneficial
824+
for performance. You should not use x after calling this function.
825+
True
826+
Ensure that the inputs are not modified.
827+
This is the default, in line with np.where.
828+
False
829+
Raise ValueError if a copy cannot be avoided.
830+
"""
831+
if x is None and y is None:
832+
xp = array_namespace(condition, use_compat=False)
833+
return xp.where(condition)
834+
835+
if copy is False:
836+
if not is_writeable_array(x):
837+
raise ValueError("Cannot modify parameter in place")
838+
elif copy is None:
839+
copy = not is_writeable_array(x)
840+
elif copy is not True:
841+
raise ValueError(f"Invalid value for copy: {copy!r}")
842+
843+
xp = array_namespace(condition, x, y, use_compat=False)
844+
if copy:
845+
return xp.where(condition, x, y)
846+
else:
847+
condition, x, y = xp.broadcast_arrays(condition, x, y)
848+
x[condition] = y[condition]
849+
return x
850+
851+
804852
__all__ = [
805853
"array_namespace",
806854
"device",
@@ -821,8 +869,10 @@ def size(x):
821869
"is_ndonnx_namespace",
822870
"is_pydata_sparse_array",
823871
"is_pydata_sparse_namespace",
872+
"is_writeable_array",
824873
"size",
825874
"to_device",
875+
"where",
826876
]
827877

828878
_all_ignore = ['sys', 'math', 'inspect', 'warnings']

docs/helper-functions.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ instead, which would be wrapped.
3636
.. autofunction:: device
3737
.. autofunction:: to_device
3838
.. autofunction:: size
39+
.. autofunction:: where
3940

4041
Inspection Helpers
4142
------------------
@@ -51,6 +52,7 @@ yet.
5152
.. autofunction:: is_jax_array
5253
.. autofunction:: is_pydata_sparse_array
5354
.. autofunction:: is_ndonnx_array
55+
.. autofunction:: is_writeable_array
5456
.. autofunction:: is_numpy_namespace
5557
.. autofunction:: is_cupy_namespace
5658
.. autofunction:: is_torch_namespace

0 commit comments

Comments
 (0)