Skip to content

Commit 5e54646

Browse files
tomwhitethodson-usgs
authored andcommitted
Implement flip (#528)
1 parent e107d7c commit 5e54646

File tree

6 files changed

+74
-2
lines changed

6 files changed

+74
-2
lines changed

.github/workflows/array-api-tests.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ jobs:
9696
# not implemented
9797
array_api_tests/test_array_object.py::test_setitem
9898
array_api_tests/test_array_object.py::test_setitem_masking
99-
array_api_tests/test_manipulation_functions.py::test_flip
10099
array_api_tests/test_sorting_functions.py
101100
array_api_tests/test_statistical_functions.py::test_std
102101
array_api_tests/test_statistical_functions.py::test_var

api_status.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
5959
| | `broadcast_to` | :white_check_mark: | | |
6060
| | `concat` | :white_check_mark: | | |
6161
| | `expand_dims` | :white_check_mark: | | |
62-
| | `flip` | :x: | | Needs indexing with step=-1, [#114](https://github.com/cubed-dev/cubed/issues/114) |
62+
| | `flip` | :white_check_mark: | | |
6363
| | `permute_dims` | :white_check_mark: | | |
6464
| | `repeat` | :x: | 2023.12 | |
6565
| | `reshape` | :white_check_mark: | | Partial implementation |

cubed/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@
279279
broadcast_to,
280280
concat,
281281
expand_dims,
282+
flip,
282283
moveaxis,
283284
permute_dims,
284285
reshape,
@@ -292,6 +293,7 @@
292293
"broadcast_to",
293294
"concat",
294295
"expand_dims",
296+
"flip",
295297
"moveaxis",
296298
"permute_dims",
297299
"reshape",

cubed/array_api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@
221221
broadcast_to,
222222
concat,
223223
expand_dims,
224+
flip,
224225
moveaxis,
225226
permute_dims,
226227
reshape,
@@ -234,6 +235,7 @@
234235
"broadcast_to",
235236
"concat",
236237
"expand_dims",
238+
"flip",
237239
"moveaxis",
238240
"permute_dims",
239241
"reshape",

cubed/array_api/manipulation_functions.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,50 @@ def flatten(x):
172172
return reshape(x, (-1,))
173173

174174

175+
def flip(x, /, *, axis=None):
176+
if axis is None:
177+
axis = tuple(range(x.ndim)) # all axes
178+
if not isinstance(axis, tuple):
179+
axis = (axis,)
180+
axis = validate_axis(axis, x.ndim)
181+
return map_direct(
182+
_flip,
183+
x,
184+
shape=x.shape,
185+
dtype=x.dtype,
186+
chunks=x.chunks,
187+
extra_projected_mem=x.chunkmem,
188+
target_chunks=x.chunks,
189+
axis=axis,
190+
)
191+
192+
193+
def _flip(x, *arrays, target_chunks=None, axis=None, block_id=None):
194+
array = arrays[0].zarray # underlying Zarr array (or virtual array)
195+
chunks = target_chunks
196+
197+
# produce a key that has slices (except for axis dimensions, which are replaced below)
198+
idx = tuple(0 if i == axis else v for i, v in enumerate(block_id))
199+
key = list(get_item(chunks, idx))
200+
201+
for ax in axis:
202+
# determine the start and stop indexes for this block along the axis dimension
203+
chunksize = to_chunksize(chunks)
204+
start = block_id[ax] * chunksize[ax]
205+
stop = start + x.shape[ax]
206+
207+
# flip start and stop
208+
axis_len = array.shape[ax]
209+
start, stop = axis_len - stop, axis_len - start
210+
211+
# replace with slice
212+
key[ax] = slice(start, stop)
213+
214+
key = tuple(key)
215+
216+
return nxp.flip(array[key], axis=axis)
217+
218+
175219
def moveaxis(
176220
x,
177221
source,

cubed/tests/test_array_api.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,31 @@ def test_expand_dims(spec, executor):
492492
assert_array_equal(b.compute(executor=executor), np.expand_dims([1, 2, 3], 0))
493493

494494

495+
@pytest.mark.parametrize(
496+
"shape, chunks, axis",
497+
[
498+
((10,), (4,), None),
499+
((10,), (4,), 0),
500+
((10, 7), (4, 3), None),
501+
((10, 7), (4, 3), 0),
502+
((10, 7), (4, 3), 1),
503+
((10, 7), (4, 3), (0, 1)),
504+
((10, 7), (4, 3), -1),
505+
],
506+
)
507+
def test_flip(executor, shape, chunks, axis):
508+
x = np.random.randint(10, size=shape)
509+
a = xp.asarray(x, chunks=chunks)
510+
b = xp.flip(a, axis=axis)
511+
512+
assert b.chunks == a.chunks
513+
514+
assert_array_equal(
515+
b.compute(executor=executor),
516+
np.flip(x, axis=axis),
517+
)
518+
519+
495520
def test_moveaxis(spec):
496521
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
497522
b = xp.moveaxis(a, [0, -1], [-1, 0])

0 commit comments

Comments
 (0)