Skip to content

Commit 40398e3

Browse files
committed
Merge branch 'main' into lazywhere
2 parents e307525 + 9f03a41 commit 40398e3

File tree

10 files changed

+809
-659
lines changed

10 files changed

+809
-659
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
- uses: pre-commit/[email protected]
3131
with:
3232
extra_args: --hook-stage manual --all-files
33-
- uses: prefix-dev/[email protected].2
33+
- uses: prefix-dev/[email protected].3
3434
with:
3535
pixi-version: v0.40.3
3636
cache: true
@@ -56,7 +56,7 @@ jobs:
5656
with:
5757
fetch-depth: 0
5858

59-
- uses: prefix-dev/[email protected].2
59+
- uses: prefix-dev/[email protected].3
6060
with:
6161
pixi-version: v0.40.3
6262
cache: true

.github/workflows/docs-build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ jobs:
77
runs-on: ubuntu-latest
88
steps:
99
- uses: actions/checkout@v4
10-
- uses: prefix-dev/[email protected].2
10+
- uses: prefix-dev/[email protected].3
1111
with:
1212
pixi-version: v0.40.3
1313
cache: true

pixi.lock

Lines changed: 591 additions & 580 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/array_api_extra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
sinc,
1616
)
1717

18-
__version__ = "0.6.1.dev0"
18+
__version__ = "0.7.0.dev0"
1919

2020
# pylint: disable=duplicate-code
2121
__all__ = [

src/array_api_extra/_lib/_at.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -278,16 +278,11 @@ def _op(
278278
msg = f"copy must be True, False, or None; got {copy!r}"
279279
raise ValueError(msg)
280280

281-
if copy is None:
282-
writeable = is_writeable_array(x)
283-
copy = not writeable
284-
elif copy:
285-
writeable = None
286-
else:
287-
writeable = is_writeable_array(x)
281+
writeable = None if copy else is_writeable_array(x)
288282

289-
# JAX inside jax.jit and Dask don't support in-place updates with boolean
290-
# mask. However we can handle the common special case of 0-dimensional y
283+
# JAX inside jax.jit doesn't support in-place updates with boolean
284+
# masks; Dask exclusively supports __setitem__ but not iops.
285+
# We can handle the common special case of 0-dimensional y
291286
# with where(idx, y, x) instead.
292287
if (
293288
(is_dask_array(idx) or is_jax_array(idx))
@@ -296,23 +291,24 @@ def _op(
296291
):
297292
y_xp = xp.asarray(y, dtype=x.dtype)
298293
if y_xp.ndim == 0:
299-
if out_of_place_op:
294+
if out_of_place_op: # add(), subtract(), ...
300295
# suppress inf warnings on Dask
301296
out = apply_where(
302297
idx, (x, y_xp), out_of_place_op, fill_value=x, xp=xp
303298
)
304299
# Undo int->float promotion on JAX after _AtOp.DIVIDE
305300
out = xp.astype(out, x.dtype, copy=False)
306-
else:
301+
else: # set()
307302
out = xp.where(idx, y_xp, x)
308303

309-
if copy:
310-
return out
311-
x[()] = out
312-
return x
304+
if copy is False:
305+
x[()] = out
306+
return x
307+
return out
308+
313309
# else: this will work on eager JAX and crash on jax.jit and Dask
314310

315-
if copy:
311+
if copy or (copy is None and not writeable):
316312
if is_jax_array(x):
317313
# Use JAX's at[]
318314
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
@@ -336,7 +332,7 @@ def _op(
336332
msg = f"Can't update read-only array {x}"
337333
raise ValueError(msg)
338334

339-
if in_place_op:
335+
if in_place_op: # add(), subtract(), ...
340336
x[self._idx] = in_place_op(x[self._idx], y)
341337
else: # set()
342338
x[self._idx] = y

src/array_api_extra/_lib/_funcs.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
14+
from ._utils._compat import array_namespace, is_jax_array
15+
from ._utils._helpers import asarrays, ndindex
1416
from ._utils._compat import (
1517
array_namespace,
1618
is_dask_namespace,
@@ -384,7 +386,7 @@ def create_diagonal(
384386
Parameters
385387
----------
386388
x : array
387-
A 1-D array.
389+
An array having shape ``(*batch_dims, k)``.
388390
offset : int, optional
389391
Offset from the leading diagonal (default is ``0``).
390392
Use positive ints for diagonals above the leading diagonal,
@@ -395,7 +397,8 @@ def create_diagonal(
395397
Returns
396398
-------
397399
array
398-
A 2-D array with `x` on the diagonal (offset by `offset`).
400+
An array having shape ``(*batch_dims, k+abs(offset), k+abs(offset))`` with `x`
401+
on the diagonal (offset by `offset`).
399402
400403
Examples
401404
--------
@@ -418,18 +421,21 @@ def create_diagonal(
418421
if xp is None:
419422
xp = array_namespace(x)
420423

421-
if x.ndim != 1:
422-
err_msg = "`x` must be 1-dimensional."
424+
if x.ndim == 0:
425+
err_msg = "`x` must be at least 1-dimensional."
423426
raise ValueError(err_msg)
424-
n = x.shape[0] + abs(offset)
425-
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))
426-
427-
start = offset if offset >= 0 else abs(offset) * n
428-
stop = min(n * (n - offset), diag.shape[0])
429-
step = n + 1
430-
diag = at(diag)[start:stop:step].set(x)
431-
432-
return xp.reshape(diag, (n, n))
427+
batch_dims = x.shape[:-1]
428+
n = x.shape[-1] + abs(offset)
429+
diag = xp.zeros((*batch_dims, n**2), dtype=x.dtype, device=_compat.device(x))
430+
431+
target_slice = slice(
432+
offset if offset >= 0 else abs(offset) * n,
433+
min(n * (n - offset), diag.shape[-1]),
434+
n + 1,
435+
)
436+
for index in ndindex(*batch_dims):
437+
diag = at(diag)[(*index, target_slice)].set(x[(*index, slice(None))])
438+
return xp.reshape(diag, (*batch_dims, n, n))
433439

434440

435441
def expand_dims(

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
44
from __future__ import annotations
55

6+
from collections.abc import Generator
67
from types import ModuleType
78
from typing import cast
89

@@ -182,6 +183,29 @@ def asarrays(
182183
return (xb, xa) if swap else (xa, xb)
183184

184185

186+
def ndindex(*x: int) -> Generator[tuple[int, ...]]:
187+
"""
188+
Generate all N-dimensional indices for a given array shape.
189+
190+
Given the shape of an array, an ndindex instance iterates over the N-dimensional
191+
index of the array. At each iteration a tuple of indices is returned, the last
192+
dimension is iterated over first.
193+
194+
This has an identical API to numpy.ndindex.
195+
196+
Parameters
197+
----------
198+
*x : int
199+
The shape of the array.
200+
"""
201+
if not x:
202+
yield ()
203+
return
204+
for i in ndindex(*x[:-1]):
205+
for j in range(x[-1]):
206+
yield *i, j
207+
208+
185209
def meta_namespace(
186210
*arrays: Array | int | float | complex | bool | None,
187211
xp: ModuleType | None = None,

0 commit comments

Comments
 (0)