Skip to content

Commit 44ec95a

Browse files
committed
Merge branch 'main' into pad-delegate
2 parents 486ebef + 70c22c0 commit 44ec95a

File tree

5 files changed

+122
-48
lines changed

5 files changed

+122
-48
lines changed

pixi.lock

Lines changed: 36 additions & 33 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,12 +204,13 @@ files = ["src", "tests"]
204204
python_version = "3.10"
205205
warn_unused_configs = true
206206
strict = true
207-
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
208-
warn_unreachable = true
207+
enable_error_code = ["ignore-without-code", "truthy-bool"]
209208
disallow_untyped_defs = false
210209
disallow_incomplete_defs = false
211210
# data-apis/array-api#589
212211
disallow_any_expr = false
212+
# false positives with input validation
213+
disable_error_code = ["redundant-expr", "unreachable"]
213214

214215
[[tool.mypy.overrides]]
215216
module = "array_api_extra.*"
@@ -232,6 +233,8 @@ reportExplicitAny = false
232233
reportUnknownMemberType = false
233234
# no array-api-compat type stubs
234235
reportUnknownVariableType = false
236+
# false positives for input validation
237+
reportUnreachable = false
235238

236239

237240
# Ruff

src/array_api_extra/_delegation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _delegate(xp: ModuleType, *backends: IsNamespace) -> bool:
7373

7474
def pad(
7575
x: Array,
76-
pad_width: int,
76+
pad_width: int | tuple[int, int] | list[tuple[int, int]],
7777
mode: str = "constant",
7878
*,
7979
constant_values: bool | int | float | complex = 0,
@@ -86,8 +86,12 @@ def pad(
8686
----------
8787
x : array
8888
Input array.
89-
pad_width : int
89+
pad_width : int or tuple of ints or list of pairs of ints
9090
Pad the input array with this many elements from each side.
91+
If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
92+
each pair applies to the corresponding axis of ``x``.
93+
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
94+
copies of this tuple.
9195
mode : str, optional
9296
Only "constant" mode is currently supported, which pads with
9397
the value passed to `constant_values`.

src/array_api_extra/_lib/_funcs.py

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -551,19 +551,47 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
551551

552552
def pad(
553553
x: Array,
554-
pad_width: int,
554+
pad_width: int | tuple[int, int] | list[tuple[int, int]],
555555
*,
556556
constant_values: bool | int | float | complex = 0,
557557
xp: ModuleType,
558558
) -> Array: # numpydoc ignore=PR01,RT01
559559
"""See docstring in `array_api_extra._delegation.py`."""
560+
# make pad_width a list of length-2 tuples of ints
561+
x_ndim = cast(int, x.ndim)
562+
if isinstance(pad_width, int):
563+
pad_width = [(pad_width, pad_width)] * x_ndim
564+
if isinstance(pad_width, tuple):
565+
pad_width = [pad_width] * x_ndim
566+
567+
# https://github.com/data-apis/array-api-extra/pull/82#discussion_r1905688819
568+
slices: list[slice] = [] # type: ignore[no-any-explicit]
569+
newshape: list[int] = []
570+
for ax, w_tpl in enumerate(pad_width):
571+
if len(w_tpl) != 2:
572+
msg = f"expect a 2-tuple (before, after), got {w_tpl}."
573+
raise ValueError(msg)
574+
575+
sh = x.shape[ax]
576+
if w_tpl[0] == 0 and w_tpl[1] == 0:
577+
sl = slice(None, None, None)
578+
else:
579+
start, stop = w_tpl
580+
stop = None if stop == 0 else -stop
581+
582+
sl = slice(start, stop, None)
583+
sh += w_tpl[0] + w_tpl[1]
584+
585+
newshape.append(sh)
586+
slices.append(sl)
587+
560588
padded = xp.full(
561-
tuple(x + 2 * pad_width for x in x.shape),
589+
tuple(newshape),
562590
fill_value=constant_values,
563591
dtype=x.dtype,
564592
device=_compat.device(x),
565593
)
566-
padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x
594+
padded[tuple(slices)] = x
567595
return padded
568596

569597

@@ -613,22 +641,39 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
613641
614642
Warnings
615643
--------
616-
(a) When you omit the ``copy`` parameter, you should always immediately overwrite
617-
the parameter array::
644+
(a) When you omit the ``copy`` parameter, you should never reuse the parameter
645+
array later on; ideally, you should reassign it immediately::
618646
619647
>>> import array_api_extra as xpx
620648
>>> x = xpx.at(x, 0).set(2)
621649
622-
The anti-pattern below must be avoided, as it will result in different
623-
behaviour on read-only versus writeable arrays::
650+
The above best practice pattern ensures that the behaviour won't change depending
651+
on whether ``x`` is writeable or not, as the original ``x`` object is dereferenced
652+
as soon as ``xpx.at`` returns; this way there is no risk to accidentally update it
653+
twice.
654+
655+
On the reverse, the anti-pattern below must be avoided, as it will result in
656+
different behaviour on read-only versus writeable arrays::
624657
625658
>>> x = xp.asarray([0, 0, 0])
626659
>>> y = xpx.at(x, 0).set(2)
627660
>>> z = xpx.at(x, 1).set(3)
628661
629-
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
630-
when ``x`` is read-only, whereas ``x == y == z == [2, 3, 0]`` when ``x`` is
631-
writeable!
662+
In the above example, both calls to ``xpx.at`` update ``x`` in place *if possible*.
663+
This causes the behaviour to diverge depending on whether ``x`` is writeable or not:
664+
665+
- If ``x`` is writeable, then after the snippet above you'll have
666+
``x == y == z == [2, 3, 0]``
667+
- If ``x`` is read-only, then you'll end up with
668+
``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and ``z == [0, 3, 0]``.
669+
670+
The correct pattern to use if you want diverging outputs from the same input is
671+
to enforce copies::
672+
673+
>>> x = xp.asarray([0, 0, 0])
674+
>>> y = xpx.at(x, 0).set(2, copy=True) # Never updates x
675+
>>> z = xpx.at(x, 1).set(3) # May or may not update x in place
676+
>>> del x # avoid accidental reuse of x as we don't know its state anymore
632677
633678
(b) The array API standard does not support integer array indices.
634679
The behaviour of update methods when the index is an array of integers is
@@ -728,7 +773,7 @@ def _update_common(
728773
raise ValueError(msg)
729774

730775
if copy not in (True, False, None):
731-
msg = f"copy must be True, False, or None; got {copy!r}" # pyright: ignore[reportUnreachable]
776+
msg = f"copy must be True, False, or None; got {copy!r}"
732777
raise ValueError(msg)
733778

734779
if copy is None:

tests/test_funcs.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,22 @@ def test_device(self):
416416

417417
def test_xp(self):
418418
assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3))
419+
420+
def test_tuple_width(self):
421+
a = xp.reshape(xp.arange(12), (3, 4))
422+
padded = pad(a, (1, 0))
423+
assert padded.shape == (4, 5)
424+
425+
padded = pad(a, (1, 2))
426+
assert padded.shape == (6, 7)
427+
428+
with pytest.raises(ValueError, match="expect a 2-tuple"):
429+
pad(a, [(1, 2, 3)]) # type: ignore[list-item] # pyright: ignore[reportArgumentType]
430+
431+
def test_list_of_tuples_width(self):
432+
a = xp.reshape(xp.arange(12), (3, 4))
433+
padded = pad(a, [(1, 0), (0, 2)])
434+
assert padded.shape == (4, 6)
435+
436+
padded = pad(a, [(1, 0), (0, 0)])
437+
assert padded.shape == (4, 4)

0 commit comments

Comments
 (0)