Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"array-api": ("https://data-apis.org/array-api/draft", None),
"jax": ("https://jax.readthedocs.io/en/latest", None),
}

Expand Down
4 changes: 2 additions & 2 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def isclose(

Parameters
----------
a, b : Array
Input arrays to compare.
a, b : Array | int | float | complex | bool
Input objects to compare. At least one must be an Array API object.
rtol : array_like, optional
The relative tolerance parameter (see Notes).
atol : array_like, optional
Expand Down
11 changes: 7 additions & 4 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ._at import at
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace, is_jax_array
from ._utils._helpers import asarrays
from ._utils._typing import Array

__all__ = [
Expand Down Expand Up @@ -315,6 +316,7 @@ def isclose(
xp: ModuleType,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
a, b = asarrays(a, b, xp=xp)

a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
Expand Down Expand Up @@ -356,8 +358,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:

Parameters
----------
a, b : array
Input arrays.
a, b : Array | int | float | complex
Input arrays or scalars. At least one must be an Array API object.
xp : array_namespace, optional
The standard-compatible namespace for `a` and `b`. Default: infer.

Expand Down Expand Up @@ -420,10 +422,10 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
if xp is None:
xp = array_namespace(a, b)
a, b = asarrays(a, b, xp=xp)

b = xp.asarray(b)
singletons = (1,) * (b.ndim - a.ndim)
a = xp.broadcast_to(xp.asarray(a), singletons + a.shape)
a = xp.broadcast_to(a, singletons + a.shape)

nd_b, nd_a = b.ndim, a.ndim
nd_max = max(nd_b, nd_a)
Expand Down Expand Up @@ -583,6 +585,7 @@ def setdiff1d(
"""
if xp is None:
xp = array_namespace(x1, x2)
x1, x2 = asarrays(x1, x2, xp=xp)

if assume_unique:
x1 = xp.reshape(x1, (-1,))
Expand Down
84 changes: 84 additions & 0 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from __future__ import annotations

from types import ModuleType
from typing import cast

from . import _compat
from ._compat import is_array_api_obj, is_numpy_array
from ._typing import Array

__all__ = ["in1d", "mean"]
Expand Down Expand Up @@ -91,3 +93,85 @@ def mean(
mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)
return mean_real + (mean_imag * xp.asarray(1j))
return xp.mean(x, axis=axis, keepdims=keepdims)


def is_python_scalar(x: object) -> bool: # numpydoc ignore=PR01,RT01
"""Return True if `x` is a Python scalar, False otherwise."""
# isinstance(x, float) returns True for np.float64
# isinstance(x, complex) returns True for np.complex128
return isinstance(x, int | float | complex | bool) and not is_numpy_array(x)


def asarrays(
a: Array | int | float | complex | bool,
b: Array | int | float | complex | bool,
xp: ModuleType,
) -> tuple[Array, Array]:
"""
Ensure both `a` and `b` are arrays.

If `b` is a python scalar, it is converted to the same dtype as `a`, and vice versa.

Behavior is not specified when mixing a Python ``float`` and an array with an
integer data type; this may give ``float32``, ``float64``, or raise an exception.
Behavior is implementation-specific.

Similarly, behavior is not specified when mixing a Python ``complex`` and an array
with a real-valued data type; this may give ``complex64``, ``complex128``, or raise
an exception. Behavior is implementation-specific.

Parameters
----------
a, b : Array | int | float | complex | bool
Input arrays or scalars. At least one must be an array.
xp : ModuleType
The array API namespace.

Returns
-------
Array, Array
The input arrays, possibly converted to arrays if they were scalars.

See Also
--------
mixing-arrays-with-python-scalars : Array API specification for the behavior.
"""
a_scalar = is_python_scalar(a)
b_scalar = is_python_scalar(b)
if not a_scalar and not b_scalar:
return a, b # This includes misc. malformed input e.g. str

swap = False
if a_scalar:
swap = True
b, a = a, b

if is_array_api_obj(a):
# a is an Array API object
# b is a int | float | complex | bool

# pyright doesn't like it if you reuse the same variable name
xa = cast(Array, a)

# https://data-apis.org/array-api/draft/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
same_dtype = {
bool: "bool",
int: ("integral", "real floating", "complex floating"),
float: ("real floating", "complex floating"),
complex: "complex floating",
}
kind = same_dtype[type(b)] # type: ignore[index]
if xp.isdtype(xa.dtype, kind):
xb = xp.asarray(b, dtype=xa.dtype)
else:
# Undefined behaviour. Let the function deal with it, if it can.
xb = xp.asarray(b)

else:
# Neither a nor b are Array API objects.
# Note: we can only reach this point when one explicitly passes
# xp=xp to the calling function; otherwise we fail earlier on
# array_namespace(a, b).
xa, xb = xp.asarray(a), xp.asarray(b)

return (xb, xa) if swap else (xa, xb)
68 changes: 56 additions & 12 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,24 @@ def test_none_shape_bool(self, xp: ModuleType):
a = a[a]
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))

@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="Array API 2024.12 support")
def test_python_scalar(self, xp: ModuleType):
a = xp.asarray([0.0, 0.1], dtype=xp.float32)
xp_assert_equal(isclose(a, 0.0), xp.asarray([True, False]))
xp_assert_equal(isclose(0.0, a), xp.asarray([True, False]))

a = xp.asarray([0, 1], dtype=xp.int16)
xp_assert_equal(isclose(a, 0), xp.asarray([True, False]))
xp_assert_equal(isclose(0, a), xp.asarray([True, False]))

xp_assert_equal(isclose(0, 0, xp=xp), xp.asarray(True))
xp_assert_equal(isclose(0, 1, xp=xp), xp.asarray(False))

def test_all_python_scalars(self):
with pytest.raises(TypeError, match="Unrecognized"):
isclose(0, 0)

def test_xp(self, xp: ModuleType):
a = xp.asarray([0.0, 0.0])
b = xp.asarray([1e-9, 1e-4])
Expand All @@ -413,30 +431,22 @@ def test_basic(self, xp: ModuleType):
# Using 0-dimensional array
a = xp.asarray(1)
b = xp.asarray([[1, 2], [3, 4]])
k = xp.asarray([[1, 2], [3, 4]])
xp_assert_equal(kron(a, b), k)
a = xp.asarray([[1, 2], [3, 4]])
b = xp.asarray(1)
xp_assert_equal(kron(a, b), k)
xp_assert_equal(kron(a, b), b)
xp_assert_equal(kron(b, a), b)

# Using 1-dimensional array
a = xp.asarray([3])
b = xp.asarray([[1, 2], [3, 4]])
k = xp.asarray([[3, 6], [9, 12]])
xp_assert_equal(kron(a, b), k)
a = xp.asarray([[1, 2], [3, 4]])
b = xp.asarray([3])
xp_assert_equal(kron(a, b), k)
xp_assert_equal(kron(b, a), k)

# Using 3-dimensional array
a = xp.asarray([[[1]], [[2]]])
b = xp.asarray([[1, 2], [3, 4]])
k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
xp_assert_equal(kron(a, b), k)
a = xp.asarray([[1, 2], [3, 4]])
b = xp.asarray([[[1]], [[2]]])
k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
xp_assert_equal(kron(a, b), k)
xp_assert_equal(kron(b, a), k)

def test_kron_smoke(self, xp: ModuleType):
a = xp.ones((3, 3))
Expand Down Expand Up @@ -474,6 +484,18 @@ def test_kron_shape(
k = kron(a, b)
assert k.shape == expected_shape

def test_python_scalar(self, xp: ModuleType):
a = 1
# Test no dtype promotion to xp.asarray(a); use b.dtype
b = xp.asarray([[1, 2], [3, 4]], dtype=xp.int16)
xp_assert_equal(kron(a, b), b)
xp_assert_equal(kron(b, a), b)
xp_assert_equal(kron(1, 1, xp=xp), xp.asarray(1))

def test_all_python_scalars(self):
with pytest.raises(TypeError, match="Unrecognized"):
kron(1, 1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does NOT fail if xp=jax.numpy, because xp_lazy_function converts everything to jax.


def test_device(self, xp: ModuleType, device: Device):
x1 = xp.asarray([1, 2, 3], device=device)
x2 = xp.asarray([4, 5], device=device)
Expand Down Expand Up @@ -601,6 +623,28 @@ def test_shapes(
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
xp_assert_equal(actual, xp.empty((0,)))

@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
@pytest.mark.parametrize("assume_unique", [True, False])
def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
# Test no dtype promotion to xp.asarray(x2); use x1.dtype
x1 = xp.asarray([3, 1, 2], dtype=xp.int16)
x2 = 3
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))

actual = setdiff1d(x2, x1, assume_unique=assume_unique)
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))

xp_assert_equal(
setdiff1d(0, 0, assume_unique=assume_unique, xp=xp),
xp.asarray([0])[:0], # Default int dtype for backend
)

@pytest.mark.parametrize("assume_unique", [True, False])
def test_all_python_scalars(self, assume_unique: bool):
with pytest.raises(TypeError, match="Unrecognized"):
setdiff1d(0, 0, assume_unique=assume_unique)

def test_device(self, xp: ModuleType, device: Device):
x1 = xp.asarray([3, 8, 20], device=device)
x2 = xp.asarray([2, 3, 4], device=device)
Expand Down
97 changes: 96 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from types import ModuleType

import numpy as np
import pytest

from array_api_extra._lib import Backend
from array_api_extra._lib._testing import xp_assert_equal
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._helpers import in1d
from array_api_extra._lib._utils._helpers import asarrays, in1d
from array_api_extra._lib._utils._typing import Device
from array_api_extra.testing import lazy_xp_function

Expand Down Expand Up @@ -45,3 +46,97 @@ def test_xp(self, xp: ModuleType):
expected = xp.asarray([True, False])
actual = in1d(x1, x2, xp=xp)
xp_assert_equal(actual, expected)


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
@pytest.mark.parametrize(
("dtype", "b", "defined"),
[
# Well-defined cases of dtype promotion from Python scalar to Array
# bool vs. bool
("bool", True, True),
# int vs. xp.*int*, xp.float*, xp.complex*
("int16", 1, True),
("uint8", 1, True),
("float32", 1, True),
("float64", 1, True),
("complex64", 1, True),
("complex128", 1, True),
# float vs. xp.float, xp.complex
("float32", 1.0, True),
("float64", 1.0, True),
("complex64", 1.0, True),
("complex128", 1.0, True),
# complex vs. xp.complex
("complex64", 1.0j, True),
("complex128", 1.0j, True),
# Undefined cases
("bool", 1, False),
("int64", 1.0, False),
("float64", 1.0j, False),
],
)
def test_asarrays_array_vs_scalar(
dtype: str, b: int | float | complex, defined: bool, xp: ModuleType
):
a = xp.asarray(1, dtype=getattr(xp, dtype))

xa, xb = asarrays(a, b, xp)
assert xa.dtype == a.dtype
if defined:
assert xb.dtype == a.dtype
else:
assert xb.dtype == xp.asarray(b).dtype

xbr, xar = asarrays(b, a, xp)
assert xar.dtype == xa.dtype
assert xbr.dtype == xb.dtype


def test_asarrays_scalar_vs_scalar(xp: ModuleType):
a, b = asarrays(1, 2.2, xp=xp)
assert a.dtype == xp.asarray(1).dtype # Default dtype
assert b.dtype == xp.asarray(2.2).dtype # Default dtype; not broadcasted


ALL_TYPES = (
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"float32",
"float64",
"complex64",
"complex128",
"bool",
)


@pytest.mark.parametrize("a_type", ALL_TYPES)
@pytest.mark.parametrize("b_type", ALL_TYPES)
def test_asarrays_array_vs_array(a_type: str, b_type: str, xp: ModuleType):
"""
Test that when both inputs of asarray are already Array API objects,
they are returned unchanged.
"""
a = xp.asarray(1, dtype=getattr(xp, a_type))
b = xp.asarray(1, dtype=getattr(xp, b_type))
xa, xb = asarrays(a, b, xp)
assert xa.dtype == a.dtype
assert xb.dtype == b.dtype


@pytest.mark.parametrize("dtype", [np.float64, np.complex128])
def test_asarrays_numpy_generics(dtype: type):
"""
Test special case of np.float64 and np.complex128,
which are subclasses of float and complex.
"""
a = dtype(0)
xa, xb = asarrays(a, 0, xp=np)
assert xa.dtype == dtype
assert xb.dtype == dtype