@@ -551,19 +551,47 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
551551
552552def pad (
553553 x : Array ,
554- pad_width : int ,
554+ pad_width : int | tuple [ int , int ] | list [ tuple [ int , int ]] ,
555555 * ,
556556 constant_values : bool | int | float | complex = 0 ,
557557 xp : ModuleType ,
558558) -> Array : # numpydoc ignore=PR01,RT01
559559 """See docstring in `array_api_extra._delegation.py`."""
560+ # make pad_width a list of length-2 tuples of ints
561+ x_ndim = cast (int , x .ndim )
562+ if isinstance (pad_width , int ):
563+ pad_width = [(pad_width , pad_width )] * x_ndim
564+ if isinstance (pad_width , tuple ):
565+ pad_width = [pad_width ] * x_ndim
566+
567+ # https://github.com/data-apis/array-api-extra/pull/82#discussion_r1905688819
568+ slices : list [slice ] = [] # type: ignore[no-any-explicit]
569+ newshape : list [int ] = []
570+ for ax , w_tpl in enumerate (pad_width ):
571+ if len (w_tpl ) != 2 :
572+ msg = f"expect a 2-tuple (before, after), got { w_tpl } ."
573+ raise ValueError (msg )
574+
575+ sh = x .shape [ax ]
576+ if w_tpl [0 ] == 0 and w_tpl [1 ] == 0 :
577+ sl = slice (None , None , None )
578+ else :
579+ start , stop = w_tpl
580+ stop = None if stop == 0 else - stop
581+
582+ sl = slice (start , stop , None )
583+ sh += w_tpl [0 ] + w_tpl [1 ]
584+
585+ newshape .append (sh )
586+ slices .append (sl )
587+
560588 padded = xp .full (
561- tuple (x + 2 * pad_width for x in x . shape ),
589+ tuple (newshape ),
562590 fill_value = constant_values ,
563591 dtype = x .dtype ,
564592 device = _compat .device (x ),
565593 )
566- padded [( slice ( pad_width , - pad_width , None ),) * x . ndim ] = x
594+ padded [tuple ( slices ) ] = x
567595 return padded
568596
569597
@@ -613,22 +641,39 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
613641
614642 Warnings
615643 --------
616- (a) When you omit the ``copy`` parameter, you should always immediately overwrite
617- the parameter array ::
644+ (a) When you omit the ``copy`` parameter, you should never reuse the parameter
645+ array later on; ideally, you should reassign it immediately ::
618646
619647 >>> import array_api_extra as xpx
620648 >>> x = xpx.at(x, 0).set(2)
621649
622- The anti-pattern below must be avoided, as it will result in different
623- behaviour on read-only versus writeable arrays::
650+ The above best practice pattern ensures that the behaviour won't change depending
651+ on whether ``x`` is writeable or not, as the original ``x`` object is dereferenced
652+ as soon as ``xpx.at`` returns; this way there is no risk to accidentally update it
653+ twice.
654+
655+ On the reverse, the anti-pattern below must be avoided, as it will result in
656+ different behaviour on read-only versus writeable arrays::
624657
625658 >>> x = xp.asarray([0, 0, 0])
626659 >>> y = xpx.at(x, 0).set(2)
627660 >>> z = xpx.at(x, 1).set(3)
628661
629- In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
630- when ``x`` is read-only, whereas ``x == y == z == [2, 3, 0]`` when ``x`` is
631- writeable!
662+ In the above example, both calls to ``xpx.at`` update ``x`` in place *if possible*.
663+ This causes the behaviour to diverge depending on whether ``x`` is writeable or not:
664+
665+ - If ``x`` is writeable, then after the snippet above you'll have
666+ ``x == y == z == [2, 3, 0]``
667+ - If ``x`` is read-only, then you'll end up with
668+ ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and ``z == [0, 3, 0]``.
669+
670+ The correct pattern to use if you want diverging outputs from the same input is
671+ to enforce copies::
672+
673+ >>> x = xp.asarray([0, 0, 0])
674+ >>> y = xpx.at(x, 0).set(2, copy=True) # Never updates x
675+ >>> z = xpx.at(x, 1).set(3) # May or may not update x in place
676+ >>> del x # avoid accidental reuse of x as we don't know its state anymore
632677
633678 (b) The array API standard does not support integer array indices.
634679 The behaviour of update methods when the index is an array of integers is
@@ -728,7 +773,7 @@ def _update_common(
728773 raise ValueError (msg )
729774
730775 if copy not in (True , False , None ):
731- msg = f"copy must be True, False, or None; got { copy !r} " # pyright: ignore[reportUnreachable]
776+ msg = f"copy must be True, False, or None; got { copy !r} "
732777 raise ValueError (msg )
733778
734779 if copy is None :
0 commit comments