@@ -554,5 +554,154 @@ def test_check_grad_normal(self):
554554 )
555555
556556
557+ def get_places ():
558+ places = []
559+ if paddle .base .is_compiled_with_cuda () or is_custom_device ():
560+ places .append (get_device_place ())
561+ places .append (paddle .CPUPlace ())
562+ return places
563+
564+
565+ class TestIndexAddAPI_Compatibility (unittest .TestCase ):
566+ def setUp (self ):
567+ np .random .seed (2025 )
568+ self .places = get_places ()
569+ self .shape = [10 , 20 ]
570+ self .index_shape = [5 ]
571+ self .axis = 1
572+ self .dtype = "float32"
573+ self .value_shape = list (self .shape )
574+ self .value_shape [self .axis ] = self .index_shape [0 ]
575+ self .init_data ()
576+
577+ def init_data (self ):
578+ self .np_input = np .random .rand (* self .shape ).astype (self .dtype )
579+ self .np_index = np .random .randint (
580+ 0 , self .shape [self .axis ], self .index_shape
581+ ).astype ("int64" )
582+ self .np_value = np .random .rand (* self .value_shape ).astype (self .dtype )
583+
584+ def get_ref_out (self , alpha = 1.0 ):
585+ ref_out = np .copy (self .np_input )
586+ idx = [slice (None )] * len (self .shape )
587+ idx [self .axis ] = self .np_index
588+ np .add .at (ref_out , tuple (idx ), self .np_value * alpha )
589+ return ref_out
590+
591+ def test_dygraph_Compatibility (self ):
592+ paddle .disable_static ()
593+ x = paddle .to_tensor (self .np_input )
594+ index = paddle .to_tensor (self .np_index )
595+ value = paddle .to_tensor (self .np_value )
596+ paddle_dygraph_out = []
597+
598+ ref_out = self .get_ref_out ()
599+ # 1. Position args (Paddle style: x, index, axis, value)
600+ out1 = paddle .index_add (x , index , self .axis , value )
601+ paddle_dygraph_out .append (out1 )
602+ # 2. Key words args (kwargs) for paddle
603+ out2 = paddle .index_add (x = x , index = index , axis = self .axis , value = value )
604+ paddle_dygraph_out .append (out2 )
605+ # 3. Key words args (kwargs) for torch
606+ out3 = paddle .index_add (
607+ input = x , dim = self .axis , index = index , source = value
608+ )
609+ paddle_dygraph_out .append (out3 )
610+ # 4. PyTorch positional args order: (input, dim, index, source)
611+ out4 = paddle .index_add (x , self .axis , index , value )
612+ paddle_dygraph_out .append (out4 )
613+ # 5. Tensor method args (Paddle style)
614+ out5 = x .index_add (index , self .axis , value )
615+ paddle_dygraph_out .append (out5 )
616+ # 6. Tensor method kwargs (PyTorch style)
617+ out6 = x .index_add (dim = self .axis , index = index , source = value )
618+ paddle_dygraph_out .append (out6 )
619+ # 7. Test 'out' parameter
620+ out7 = paddle .empty_like (x )
621+ paddle .index_add (
622+ input = x , dim = self .axis , index = index , source = value , out = out7
623+ )
624+ paddle_dygraph_out .append (out7 )
625+ # 8. Test 'alpha' parameter
626+ alpha = 2.0
627+ out8 = paddle .index_add (
628+ input = x , dim = self .axis , index = index , source = value , alpha = alpha
629+ )
630+ out9 = paddle .index_add_ (
631+ input = x , dim = self .axis , index = index , source = value , alpha = alpha
632+ )
633+ ref_out_alpha = self .get_ref_out (alpha = alpha )
634+
635+ for out in paddle_dygraph_out :
636+ np .testing .assert_allclose (ref_out , out .numpy (), rtol = 1e-05 )
637+ np .testing .assert_allclose (ref_out_alpha , out8 .numpy (), rtol = 1e-05 )
638+ np .testing .assert_allclose (ref_out_alpha , out9 .numpy (), rtol = 1e-05 )
639+ paddle .enable_static ()
640+
641+ def test_static_Compatibility (self ):
642+ paddle .enable_static ()
643+ main = paddle .static .Program ()
644+ startup = paddle .static .Program ()
645+ with paddle .base .program_guard (main , startup ):
646+ x = paddle .static .data (name = "x" , shape = self .shape , dtype = self .dtype )
647+ index = paddle .static .data (
648+ name = "index" , shape = self .index_shape , dtype = "int64"
649+ )
650+ value = paddle .static .data (
651+ name = "value" , shape = self .value_shape , dtype = self .dtype
652+ )
653+ # 1. Position args (Paddle style: x, index, axis, value)
654+ out1 = paddle .index_add (x , index , self .axis , value )
655+ # 2. Key words args (kwargs) for paddle
656+ out2 = paddle .index_add (
657+ x = x , index = index , axis = self .axis , value = value
658+ )
659+ # 3. Key words args (kwargs) for torch
660+ out3 = paddle .index_add (
661+ input = x , dim = self .axis , index = index , source = value
662+ )
663+ # 4. PyTorch positional args order: (input, dim, index, source)
664+ out4 = paddle .index_add (x , self .axis , index , value )
665+ # 5. Tensor method args (Paddle style)
666+ out5 = x .index_add (index , self .axis , value )
667+ # 6. Tensor method kwargs (PyTorch style)
668+ out6 = x .index_add (dim = self .axis , index = index , source = value )
669+ # 7. Test 'alpha' parameter
670+ alpha = 2.0
671+ out7 = paddle .index_add (
672+ input = x , dim = self .axis , index = index , source = value , alpha = alpha
673+ )
674+ ref_out = self .get_ref_out ()
675+ ref_out_alpha = self .get_ref_out (alpha = alpha )
676+
677+ fetch_list = [
678+ out1 ,
679+ out2 ,
680+ out3 ,
681+ out4 ,
682+ out5 ,
683+ out6 ,
684+ out7 ,
685+ ]
686+ feed_dict = {
687+ "x" : self .np_input ,
688+ "index" : self .np_index ,
689+ "value" : self .np_value ,
690+ }
691+
692+ for place in self .places :
693+ exe = paddle .base .Executor (place )
694+ fetches = exe .run (
695+ main ,
696+ feed = feed_dict ,
697+ fetch_list = fetch_list ,
698+ )
699+ for out in fetches [:- 1 ]:
700+ np .testing .assert_allclose (out , ref_out , rtol = 1e-05 )
701+ np .testing .assert_allclose (
702+ fetches [- 1 ], ref_out_alpha , rtol = 1e-05
703+ )
704+
705+
557706if __name__ == '__main__' :
558707 unittest .main ()
0 commit comments