Skip to content

Commit a5bd401

Browse files
authored
[API Compatiblity] Support paddle.index_add (PaddlePaddle#76170)
* support index_add * fix * fix * fix UT
1 parent c125f23 commit a5bd401

File tree

3 files changed

+206
-4
lines changed

3 files changed

+206
-4
lines changed

python/paddle/tensor/manipulation.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ParamAliasDecorator,
3131
VariableArgsDecorator,
3232
expand_decorator,
33+
index_add_decorator,
3334
param_one_alias,
3435
param_two_alias,
3536
reshape_decorator,
@@ -58,6 +59,7 @@
5859

5960
if TYPE_CHECKING:
6061
from collections.abc import Callable, Sequence
62+
from numbers import Number
6163

6264
from paddle import Tensor
6365
from paddle._typing import (
@@ -7599,18 +7601,31 @@ def scatter_add_(
75997601
)
76007602

76017603

7604+
@index_add_decorator()
76027605
def index_add(
7603-
x: Tensor, index: Tensor, axis: int, value: Tensor, name: str | None = None
7606+
x: Tensor,
7607+
index: Tensor,
7608+
axis: int,
7609+
value: Tensor,
7610+
alpha: Number = 1,
7611+
name: str | None = None,
7612+
*,
7613+
out: Tensor | None = None,
76047614
) -> Tensor:
76057615
"""
76067616
Adds the elements of the input tensor with value tensor by selecting the indices in the order given in index.
76077617
76087618
Args:
76097619
x (Tensor) : The Destination Tensor. Supported data types are int32, int64, float16, float32, float64.
7620+
alias: ``input``.
76107621
index (Tensor): The 1-D Tensor containing the indices to index.
76117622
The data type of ``index`` must be int32 or int64.
76127623
axis (int): The dimension in which we index.
7624+
alias: ``dim``.
76137625
value (Tensor): The tensor used to add the elements along the target axis.
7626+
alias: ``source``.
7627+
alpha (Number, optional): Scaling factor for value. Default: 1.
7628+
out (Tensor, optional): The output tensor. Default: None.
76147629
name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
76157630
76167631
Returns:
@@ -7634,7 +7649,8 @@ def index_add(
76347649
[2., 2., 2.]])
76357650
"""
76367651
if in_dynamic_or_pir_mode():
7637-
return _C_ops.index_add(x, index, value, axis)
7652+
scaled_value = value * alpha if alpha != 1 else value
7653+
return _C_ops.index_add(x, index, scaled_value, axis, out=out)
76387654

76397655
helper = LayerHelper("index_add", **locals())
76407656
check_variable_and_dtype(
@@ -7671,15 +7687,22 @@ def index_add(
76717687
return out
76727688

76737689

7690+
@index_add_decorator()
76747691
@inplace_apis_in_dygraph_only
76757692
def index_add_(
7676-
x: Tensor, index: Tensor, axis: int, value: Tensor, name: str | None = None
7693+
x: Tensor,
7694+
index: Tensor,
7695+
axis: int,
7696+
value: Tensor,
7697+
alpha: int = 1,
7698+
name: str | None = None,
76777699
) -> Tensor:
76787700
"""
76797701
Inplace version of ``index_add`` API, the output Tensor will be inplaced with input ``x``.
76807702
Please refer to :ref:`api_paddle_index_add`.
76817703
"""
7682-
return _C_ops.index_add_(x, index, value, axis)
7704+
scaled_value = value * alpha if alpha != 1 else value
7705+
return _C_ops.index_add_(x, index, scaled_value, axis)
76837706

76847707

76857708
def unflatten(

python/paddle/utils/decorator_utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,3 +913,33 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
913913

914914
wrapper.__signature__ = inspect.signature(init_func)
915915
return wrapper
916+
917+
918+
def index_add_decorator() -> Callable[
919+
[Callable[_InputT, _RetT]], Callable[_InputT, _RetT]
920+
]:
921+
def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
922+
@functools.wraps(func)
923+
def wrapper(*args, **kwargs) -> _RetT:
924+
if "input" in kwargs:
925+
kwargs["x"] = kwargs.pop("input")
926+
if "dim" in kwargs:
927+
kwargs["axis"] = kwargs.pop("dim")
928+
if "source" in kwargs:
929+
kwargs["value"] = kwargs.pop("source")
930+
931+
if len(args) >= 2 and isinstance(args[1], int):
932+
kwargs["x"] = args[0]
933+
kwargs["axis"] = args[1]
934+
if len(args) > 2:
935+
kwargs["index"] = args[2]
936+
if len(args) > 3:
937+
kwargs["value"] = args[3]
938+
args = args[4:]
939+
940+
return func(*args, **kwargs)
941+
942+
wrapper.__signature__ = inspect.signature(func)
943+
return wrapper
944+
945+
return decorator

test/legacy_test/test_index_add_op.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,5 +554,154 @@ def test_check_grad_normal(self):
554554
)
555555

556556

557+
def get_places():
558+
places = []
559+
if paddle.base.is_compiled_with_cuda() or is_custom_device():
560+
places.append(get_device_place())
561+
places.append(paddle.CPUPlace())
562+
return places
563+
564+
565+
class TestIndexAddAPI_Compatibility(unittest.TestCase):
566+
def setUp(self):
567+
np.random.seed(2025)
568+
self.places = get_places()
569+
self.shape = [10, 20]
570+
self.index_shape = [5]
571+
self.axis = 1
572+
self.dtype = "float32"
573+
self.value_shape = list(self.shape)
574+
self.value_shape[self.axis] = self.index_shape[0]
575+
self.init_data()
576+
577+
def init_data(self):
578+
self.np_input = np.random.rand(*self.shape).astype(self.dtype)
579+
self.np_index = np.random.randint(
580+
0, self.shape[self.axis], self.index_shape
581+
).astype("int64")
582+
self.np_value = np.random.rand(*self.value_shape).astype(self.dtype)
583+
584+
def get_ref_out(self, alpha=1.0):
585+
ref_out = np.copy(self.np_input)
586+
idx = [slice(None)] * len(self.shape)
587+
idx[self.axis] = self.np_index
588+
np.add.at(ref_out, tuple(idx), self.np_value * alpha)
589+
return ref_out
590+
591+
def test_dygraph_Compatibility(self):
592+
paddle.disable_static()
593+
x = paddle.to_tensor(self.np_input)
594+
index = paddle.to_tensor(self.np_index)
595+
value = paddle.to_tensor(self.np_value)
596+
paddle_dygraph_out = []
597+
598+
ref_out = self.get_ref_out()
599+
# 1. Position args (Paddle style: x, index, axis, value)
600+
out1 = paddle.index_add(x, index, self.axis, value)
601+
paddle_dygraph_out.append(out1)
602+
# 2. Key words args (kwargs) for paddle
603+
out2 = paddle.index_add(x=x, index=index, axis=self.axis, value=value)
604+
paddle_dygraph_out.append(out2)
605+
# 3. Key words args (kwargs) for torch
606+
out3 = paddle.index_add(
607+
input=x, dim=self.axis, index=index, source=value
608+
)
609+
paddle_dygraph_out.append(out3)
610+
# 4. PyTorch positional args order: (input, dim, index, source)
611+
out4 = paddle.index_add(x, self.axis, index, value)
612+
paddle_dygraph_out.append(out4)
613+
# 5. Tensor method args (Paddle style)
614+
out5 = x.index_add(index, self.axis, value)
615+
paddle_dygraph_out.append(out5)
616+
# 6. Tensor method kwargs (PyTorch style)
617+
out6 = x.index_add(dim=self.axis, index=index, source=value)
618+
paddle_dygraph_out.append(out6)
619+
# 7. Test 'out' parameter
620+
out7 = paddle.empty_like(x)
621+
paddle.index_add(
622+
input=x, dim=self.axis, index=index, source=value, out=out7
623+
)
624+
paddle_dygraph_out.append(out7)
625+
# 8. Test 'alpha' parameter
626+
alpha = 2.0
627+
out8 = paddle.index_add(
628+
input=x, dim=self.axis, index=index, source=value, alpha=alpha
629+
)
630+
out9 = paddle.index_add_(
631+
input=x, dim=self.axis, index=index, source=value, alpha=alpha
632+
)
633+
ref_out_alpha = self.get_ref_out(alpha=alpha)
634+
635+
for out in paddle_dygraph_out:
636+
np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-05)
637+
np.testing.assert_allclose(ref_out_alpha, out8.numpy(), rtol=1e-05)
638+
np.testing.assert_allclose(ref_out_alpha, out9.numpy(), rtol=1e-05)
639+
paddle.enable_static()
640+
641+
def test_static_Compatibility(self):
642+
paddle.enable_static()
643+
main = paddle.static.Program()
644+
startup = paddle.static.Program()
645+
with paddle.base.program_guard(main, startup):
646+
x = paddle.static.data(name="x", shape=self.shape, dtype=self.dtype)
647+
index = paddle.static.data(
648+
name="index", shape=self.index_shape, dtype="int64"
649+
)
650+
value = paddle.static.data(
651+
name="value", shape=self.value_shape, dtype=self.dtype
652+
)
653+
# 1. Position args (Paddle style: x, index, axis, value)
654+
out1 = paddle.index_add(x, index, self.axis, value)
655+
# 2. Key words args (kwargs) for paddle
656+
out2 = paddle.index_add(
657+
x=x, index=index, axis=self.axis, value=value
658+
)
659+
# 3. Key words args (kwargs) for torch
660+
out3 = paddle.index_add(
661+
input=x, dim=self.axis, index=index, source=value
662+
)
663+
# 4. PyTorch positional args order: (input, dim, index, source)
664+
out4 = paddle.index_add(x, self.axis, index, value)
665+
# 5. Tensor method args (Paddle style)
666+
out5 = x.index_add(index, self.axis, value)
667+
# 6. Tensor method kwargs (PyTorch style)
668+
out6 = x.index_add(dim=self.axis, index=index, source=value)
669+
# 7. Test 'alpha' parameter
670+
alpha = 2.0
671+
out7 = paddle.index_add(
672+
input=x, dim=self.axis, index=index, source=value, alpha=alpha
673+
)
674+
ref_out = self.get_ref_out()
675+
ref_out_alpha = self.get_ref_out(alpha=alpha)
676+
677+
fetch_list = [
678+
out1,
679+
out2,
680+
out3,
681+
out4,
682+
out5,
683+
out6,
684+
out7,
685+
]
686+
feed_dict = {
687+
"x": self.np_input,
688+
"index": self.np_index,
689+
"value": self.np_value,
690+
}
691+
692+
for place in self.places:
693+
exe = paddle.base.Executor(place)
694+
fetches = exe.run(
695+
main,
696+
feed=feed_dict,
697+
fetch_list=fetch_list,
698+
)
699+
for out in fetches[:-1]:
700+
np.testing.assert_allclose(out, ref_out, rtol=1e-05)
701+
np.testing.assert_allclose(
702+
fetches[-1], ref_out_alpha, rtol=1e-05
703+
)
704+
705+
557706
if __name__ == '__main__':
558707
unittest.main()

0 commit comments

Comments
 (0)