Skip to content

Commit fb476ef

Browse files
committed
alphabetical
1 parent 6c50995 commit fb476ef

File tree

2 files changed

+150
-149
lines changed

2 files changed

+150
-149
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 81 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,87 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
410410
return xp.reshape(result, res_shape)
411411

412412

413+
def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
414+
"""
415+
Count the number of unique elements in an array.
416+
417+
Compatible with JAX and Dask, whose laziness would be otherwise
418+
problematic.
419+
420+
Parameters
421+
----------
422+
x : Array
423+
Input array.
424+
xp : array_namespace, optional
425+
The standard-compatible namespace for `x`. Default: infer.
426+
427+
Returns
428+
-------
429+
array: 0-dimensional integer array
430+
The number of unique elements in `x`. It can be lazy.
431+
"""
432+
if xp is None:
433+
xp = array_namespace(x)
434+
435+
if is_jax_array(x):
436+
# size= is JAX-specific
437+
# https://github.com/data-apis/array-api/issues/883
438+
_, counts = xp.unique_counts(x, size=_compat.size(x))
439+
return xp.astype(counts, xp.bool).sum()
440+
441+
_, counts = xp.unique_counts(x)
442+
n = _compat.size(counts)
443+
# FIXME https://github.com/data-apis/array-api-compat/pull/231
444+
if n is None or math.isnan(n): # e.g. Dask, ndonnx
445+
return xp.astype(counts, xp.bool).sum()
446+
return xp.asarray(n, device=_compat.device(x))
447+
448+
449+
def pad(
450+
x: Array,
451+
pad_width: int | tuple[int, int] | list[tuple[int, int]],
452+
*,
453+
constant_values: bool | int | float | complex = 0,
454+
xp: ModuleType,
455+
) -> Array: # numpydoc ignore=PR01,RT01
456+
"""See docstring in `array_api_extra._delegation.py`."""
457+
# make pad_width a list of length-2 tuples of ints
458+
x_ndim = cast(int, x.ndim)
459+
if isinstance(pad_width, int):
460+
pad_width = [(pad_width, pad_width)] * x_ndim
461+
if isinstance(pad_width, tuple):
462+
pad_width = [pad_width] * x_ndim
463+
464+
# https://github.com/python/typeshed/issues/13376
465+
slices: list[slice] = [] # type: ignore[no-any-explicit]
466+
newshape: list[int] = []
467+
for ax, w_tpl in enumerate(pad_width):
468+
if len(w_tpl) != 2:
469+
msg = f"expect a 2-tuple (before, after), got {w_tpl}."
470+
raise ValueError(msg)
471+
472+
sh = x.shape[ax]
473+
if w_tpl[0] == 0 and w_tpl[1] == 0:
474+
sl = slice(None, None, None)
475+
else:
476+
start, stop = w_tpl
477+
stop = None if stop == 0 else -stop
478+
479+
sl = slice(start, stop, None)
480+
sh += w_tpl[0] + w_tpl[1]
481+
482+
newshape.append(sh)
483+
slices.append(sl)
484+
485+
padded = xp.full(
486+
tuple(newshape),
487+
fill_value=constant_values,
488+
dtype=x.dtype,
489+
device=_compat.device(x),
490+
)
491+
return at(padded, tuple(slices)).set(x)
492+
493+
413494
def setdiff1d(
414495
x1: Array,
415496
x2: Array,
@@ -550,84 +631,3 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
550631
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
551632
)
552633
return xp.sin(y) / y
553-
554-
555-
def pad(
556-
x: Array,
557-
pad_width: int | tuple[int, int] | list[tuple[int, int]],
558-
*,
559-
constant_values: bool | int | float | complex = 0,
560-
xp: ModuleType,
561-
) -> Array: # numpydoc ignore=PR01,RT01
562-
"""See docstring in `array_api_extra._delegation.py`."""
563-
# make pad_width a list of length-2 tuples of ints
564-
x_ndim = cast(int, x.ndim)
565-
if isinstance(pad_width, int):
566-
pad_width = [(pad_width, pad_width)] * x_ndim
567-
if isinstance(pad_width, tuple):
568-
pad_width = [pad_width] * x_ndim
569-
570-
# https://github.com/python/typeshed/issues/13376
571-
slices: list[slice] = [] # type: ignore[no-any-explicit]
572-
newshape: list[int] = []
573-
for ax, w_tpl in enumerate(pad_width):
574-
if len(w_tpl) != 2:
575-
msg = f"expect a 2-tuple (before, after), got {w_tpl}."
576-
raise ValueError(msg)
577-
578-
sh = x.shape[ax]
579-
if w_tpl[0] == 0 and w_tpl[1] == 0:
580-
sl = slice(None, None, None)
581-
else:
582-
start, stop = w_tpl
583-
stop = None if stop == 0 else -stop
584-
585-
sl = slice(start, stop, None)
586-
sh += w_tpl[0] + w_tpl[1]
587-
588-
newshape.append(sh)
589-
slices.append(sl)
590-
591-
padded = xp.full(
592-
tuple(newshape),
593-
fill_value=constant_values,
594-
dtype=x.dtype,
595-
device=_compat.device(x),
596-
)
597-
return at(padded, tuple(slices)).set(x)
598-
599-
600-
def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
601-
"""
602-
Count the number of unique elements in an array.
603-
604-
Compatible with JAX and Dask, whose laziness would be otherwise
605-
problematic.
606-
607-
Parameters
608-
----------
609-
x : Array
610-
Input array.
611-
xp : array_namespace, optional
612-
The standard-compatible namespace for `x`. Default: infer.
613-
614-
Returns
615-
-------
616-
array: 0-dimensional integer array
617-
The number of unique elements in `x`. It can be lazy.
618-
"""
619-
if xp is None:
620-
xp = array_namespace(x)
621-
622-
if is_jax_array(x):
623-
# size= is JAX-specific
624-
# https://github.com/data-apis/array-api/issues/883
625-
_, counts = xp.unique_counts(x, size=_compat.size(x))
626-
return xp.astype(counts, xp.bool).sum()
627-
628-
_, counts = xp.unique_counts(x)
629-
n = _compat.size(counts)
630-
# FIXME https://github.com/data-apis/array-api-compat/pull/231
631-
if n is None or math.isnan(n): # e.g. Dask, ndonnx
632-
return xp.astype(counts, xp.bool).sum()
633-
return xp.asarray(n, device=_compat.device(x))

tests/test_funcs.py

Lines changed: 69 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from array_api_extra._lib._utils._compat import device as get_device
2323
from array_api_extra._lib._utils._typing import Array, Device
2424

25+
# some xp backends are untyped
2526
# mypy: disable-error-code=no-untyped-usage
2627

2728

@@ -330,6 +331,74 @@ def test_xp(self, xp: ModuleType):
330331
xp_assert_equal(kron(a, b, xp=xp), k)
331332

332333

334+
class TestNUnique:
335+
def test_simple(self, xp: ModuleType):
336+
a = xp.asarray([[1, 1], [0, 2], [2, 2]])
337+
xp_assert_equal(nunique(a), xp.asarray(3))
338+
339+
def test_empty(self, xp: ModuleType):
340+
a = xp.asarray([])
341+
xp_assert_equal(nunique(a), xp.asarray(0))
342+
343+
def test_device(self, xp: ModuleType, device: Device):
344+
a = xp.asarray(0.0, device=device)
345+
assert get_device(nunique(a)) == device
346+
347+
def test_xp(self, xp: ModuleType):
348+
a = xp.asarray([[1, 1], [0, 2], [2, 2]])
349+
xp_assert_equal(nunique(a, xp=xp), xp.asarray(3))
350+
351+
352+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no arange, no device")
353+
class TestPad:
354+
def test_simple(self, xp: ModuleType):
355+
a = xp.arange(1, 4)
356+
padded = pad(a, 2)
357+
xp_assert_equal(padded, xp.asarray([0, 0, 1, 2, 3, 0, 0]))
358+
359+
def test_fill_value(self, xp: ModuleType):
360+
a = xp.arange(1, 4)
361+
padded = pad(a, 2, constant_values=42)
362+
xp_assert_equal(padded, xp.asarray([42, 42, 1, 2, 3, 42, 42]))
363+
364+
def test_ndim(self, xp: ModuleType):
365+
a = xp.reshape(xp.arange(2 * 3 * 4), (2, 3, 4))
366+
padded = pad(a, 2)
367+
assert padded.shape == (6, 7, 8)
368+
369+
def test_mode_not_implemented(self, xp: ModuleType):
370+
a = xp.arange(3)
371+
with pytest.raises(NotImplementedError, match="Only `'constant'`"):
372+
pad(a, 2, mode="edge") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
373+
374+
def test_device(self, xp: ModuleType, device: Device):
375+
a = xp.asarray(0.0, device=device)
376+
assert get_device(pad(a, 2)) == device
377+
378+
def test_xp(self, xp: ModuleType):
379+
padded = pad(xp.asarray(0), 1, xp=xp)
380+
xp_assert_equal(padded, xp.asarray(0))
381+
382+
def test_tuple_width(self, xp: ModuleType):
383+
a = xp.reshape(xp.arange(12), (3, 4))
384+
padded = pad(a, (1, 0))
385+
assert padded.shape == (4, 5)
386+
387+
padded = pad(a, (1, 2))
388+
assert padded.shape == (6, 7)
389+
390+
with pytest.raises((ValueError, RuntimeError)):
391+
pad(a, [(1, 2, 3)]) # type: ignore[list-item] # pyright: ignore[reportArgumentType]
392+
393+
def test_list_of_tuples_width(self, xp: ModuleType):
394+
a = xp.reshape(xp.arange(12), (3, 4))
395+
padded = pad(a, [(1, 0), (0, 2)])
396+
assert padded.shape == (4, 6)
397+
398+
padded = pad(a, [(1, 0), (0, 0)])
399+
assert padded.shape == (4, 4)
400+
401+
333402
@pytest.mark.skip_xp_backend(Backend.DASK_ARRAY, reason="no argsort")
334403
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device")
335404
class TestSetDiff1D:
@@ -401,71 +470,3 @@ def test_device(self, xp: ModuleType, device: Device):
401470

402471
def test_xp(self, xp: ModuleType):
403472
xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))
404-
405-
406-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no arange, no device")
407-
class TestPad:
408-
def test_simple(self, xp: ModuleType):
409-
a = xp.arange(1, 4)
410-
padded = pad(a, 2)
411-
xp_assert_equal(padded, xp.asarray([0, 0, 1, 2, 3, 0, 0]))
412-
413-
def test_fill_value(self, xp: ModuleType):
414-
a = xp.arange(1, 4)
415-
padded = pad(a, 2, constant_values=42)
416-
xp_assert_equal(padded, xp.asarray([42, 42, 1, 2, 3, 42, 42]))
417-
418-
def test_ndim(self, xp: ModuleType):
419-
a = xp.reshape(xp.arange(2 * 3 * 4), (2, 3, 4))
420-
padded = pad(a, 2)
421-
assert padded.shape == (6, 7, 8)
422-
423-
def test_mode_not_implemented(self, xp: ModuleType):
424-
a = xp.arange(3)
425-
with pytest.raises(NotImplementedError, match="Only `'constant'`"):
426-
pad(a, 2, mode="edge") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
427-
428-
def test_device(self, xp: ModuleType, device: Device):
429-
a = xp.asarray(0.0, device=device)
430-
assert get_device(pad(a, 2)) == device
431-
432-
def test_xp(self, xp: ModuleType):
433-
padded = pad(xp.asarray(0), 1, xp=xp)
434-
xp_assert_equal(padded, xp.asarray(0))
435-
436-
def test_tuple_width(self, xp: ModuleType):
437-
a = xp.reshape(xp.arange(12), (3, 4))
438-
padded = pad(a, (1, 0))
439-
assert padded.shape == (4, 5)
440-
441-
padded = pad(a, (1, 2))
442-
assert padded.shape == (6, 7)
443-
444-
with pytest.raises((ValueError, RuntimeError)):
445-
pad(a, [(1, 2, 3)]) # type: ignore[list-item] # pyright: ignore[reportArgumentType]
446-
447-
def test_list_of_tuples_width(self, xp: ModuleType):
448-
a = xp.reshape(xp.arange(12), (3, 4))
449-
padded = pad(a, [(1, 0), (0, 2)])
450-
assert padded.shape == (4, 6)
451-
452-
padded = pad(a, [(1, 0), (0, 0)])
453-
assert padded.shape == (4, 4)
454-
455-
456-
class TestNUnique:
457-
def test_simple(self, xp: ModuleType):
458-
a = xp.asarray([[1, 1], [0, 2], [2, 2]])
459-
xp_assert_equal(nunique(a), xp.asarray(3))
460-
461-
def test_empty(self, xp: ModuleType):
462-
a = xp.asarray([])
463-
xp_assert_equal(nunique(a), xp.asarray(0))
464-
465-
def test_device(self, xp: ModuleType, device: Device):
466-
a = xp.asarray(0.0, device=device)
467-
assert get_device(nunique(a)) == device
468-
469-
def test_xp(self, xp: ModuleType):
470-
a = xp.asarray([[1, 1], [0, 2], [2, 2]])
471-
xp_assert_equal(nunique(a, xp=xp), xp.asarray(3))

0 commit comments

Comments
 (0)