Skip to content

Commit 38a2f0e

Browse files
authored
Implement diff (#781)
1 parent a61ae4b commit 38a2f0e

File tree

6 files changed

+70
-8
lines changed

6 files changed

+70
-8
lines changed

api_status.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
9090
| | `var` | :white_check_mark: | | |
9191
| Utility Functions | `all` | :white_check_mark: | | |
9292
| | `any` | :white_check_mark: | | |
93-
| | `diff` | :x: | 2024.12 | |
93+
| | `diff` | :white_check_mark: | 2024.12 | |
9494

9595
### Linear Algebra Extension
9696

cubed/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,9 @@
353353

354354
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
355355

356-
from .array_api.utility_functions import all, any
356+
from .array_api.utility_functions import all, any, diff
357357

358-
__all__ += ["all", "any"]
358+
__all__ += ["all", "any", "diff"]
359359

360360
# extensions
361361

cubed/array/overlap.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,16 @@ def get_item_with_depth(
127127
) -> Tuple[slice, ...]:
128128
"""Convert a chunk index to a tuple of slices with depth offsets."""
129129
starts = tuple(_cumsum(c, initial_zero=True) for c in chunks)
130+
131+
def depth_offsets(d):
132+
if isinstance(d, int):
133+
return -d, d
134+
return d
135+
130136
loc = tuple(
131137
(
132-
_clamp(0, start[i] - depth[ax], start[-1]),
133-
_clamp(0, start[i + 1] + depth[ax], start[-1]),
138+
_clamp(0, start[i] + depth_offsets(depth[ax])[0], start[-1]),
139+
_clamp(0, start[i + 1] + depth_offsets(depth[ax])[1], start[-1]),
134140
)
135141
for ax, (i, start) in enumerate(zip(idx, starts))
136142
)
@@ -140,7 +146,7 @@ def get_item_with_depth(
140146
def _pad_boundaries(x, depth, boundary, numblocks, block_id):
141147
for i in range(x.ndim):
142148
d = depth.get(i, 0)
143-
if d == 0 or block_id[i] not in (0, numblocks[i] - 1):
149+
if d == 0 or block_id[i] not in (0, numblocks[i] - 1) or boundary[i] == "none":
144150
continue
145151
pad_shape = list(x.shape)
146152
pad_shape[i] = d

cubed/array_api/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,6 @@
276276

277277
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
278278

279-
from .utility_functions import all, any
279+
from .utility_functions import all, any, diff
280280

281-
__all__ += ["all", "any"]
281+
__all__ += ["all", "any", "diff"]

cubed/array_api/utility_functions.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
from cubed.array.overlap import map_overlap
12
from cubed.array_api.creation_functions import asarray
3+
from cubed.array_api.manipulation_functions import concat
24
from cubed.backend_array_api import namespace as nxp
35
from cubed.core import reduction
6+
from cubed.vendor.dask.array.core import normalize_chunks
7+
from cubed.vendor.dask.array.utils import validate_axis
48

59

610
def all(x, /, *, axis=None, keepdims=False, split_every=None):
@@ -27,3 +31,34 @@ def any(x, /, *, axis=None, keepdims=False, split_every=None):
2731
keepdims=keepdims,
2832
split_every=split_every,
2933
)
34+
35+
36+
def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
37+
axis = validate_axis(axis, x.ndim)
38+
39+
if n < 0:
40+
raise ValueError(f"order of diff must be non-negative, but was {n}")
41+
if n == 0:
42+
return x
43+
44+
combined = []
45+
if prepend is not None:
46+
combined.append(prepend)
47+
combined.append(x)
48+
if append is not None:
49+
combined.append(append)
50+
if len(combined) > 1:
51+
x = concat(combined, axis=axis, chunks=x.chunksize)
52+
53+
shape = tuple(s - n if i == axis else s for i, s in enumerate(x.shape))
54+
chunks = normalize_chunks(x.chunksize, shape, dtype=x.dtype)
55+
depth = {axis: (0, n)} # only need look-ahead values for differencing
56+
return map_overlap(
57+
nxp.diff,
58+
x,
59+
dtype=x.dtype,
60+
chunks=chunks,
61+
depth=depth,
62+
axis=axis,
63+
n=n,
64+
)

cubed/tests/test_array_api.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,3 +970,24 @@ def test_all_zero_dimension(spec, executor):
970970
assert b.ndim == 0
971971
assert b.size == 1
972972
assert b.compute(executor=executor)
973+
974+
975+
@pytest.mark.parametrize("n", [1, 2])
976+
def test_diff(n):
977+
x = np.array([1, 5, 3, 8, 7, 2, 6, 9])
978+
a = xp.asarray(x, chunks=(3,))
979+
b = xp.diff(a, n=n)
980+
981+
assert_array_equal(b.compute(), np.diff(x, n=n))
982+
983+
984+
@pytest.mark.parametrize(
985+
("shape", "axis"),
986+
[((10, 15, 20), 0), ((10, 15, 20), 1), ((10, 15, 20), 2), ((10, 15, 20), -1)],
987+
)
988+
@pytest.mark.parametrize("n", [0, 1, 2])
989+
def test_diff_3d(shape, n, axis):
990+
x = np.random.default_rng().integers(0, 10, shape)
991+
a = xp.asarray(x, chunks=(len(shape) * (5,)))
992+
993+
assert_array_equal(xp.diff(a, axis=axis, n=n), np.diff(x, axis=axis, n=n))

0 commit comments

Comments
 (0)