1+ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+ # pyre-strict
4+
5+ from typing import Callable , List , Tuple
6+
7+ import torch
8+
9+ from executorch .backends .test .compliance_suite import (
10+ dtype_test ,
11+ operator_test ,
12+ OperatorTest ,
13+ )
14+
15+ class IndexPutModel (torch .nn .Module ):
16+ def __init__ (self , accumulate = False ):
17+ super ().__init__ ()
18+ self .accumulate = accumulate
19+
20+ def forward (self , x , indices , values ):
21+ # Clone the input to avoid modifying it in-place
22+ result = x .clone ()
23+ # Apply index_put_ and return the modified tensor
24+ result .index_put_ (indices , values , self .accumulate )
25+ return result
26+
27+ @operator_test
28+ class TestIndexPut (OperatorTest ):
29+ @dtype_test
30+ def test_index_put_dtype (self , dtype , tester_factory : Callable ) -> None :
31+ # Test with different dtypes
32+ indices = (torch .tensor ([0 , 2 ]),)
33+ values = torch .tensor ([10.0 , 20.0 ]).to (dtype )
34+ model = IndexPutModel ()
35+ self ._test_op (model , ((torch .rand (5 , 2 ) * 100 ).to (dtype ), indices , values ), tester_factory , use_random_test_inputs = False )
36+
37+ def test_index_put_basic (self , tester_factory : Callable ) -> None :
38+ # Basic test with default parameters
39+ indices = (torch .tensor ([0 , 2 ]),)
40+ values = torch .tensor ([10.0 , 20.0 ])
41+ self ._test_op (IndexPutModel (), (torch .randn (5 , 2 ), indices , values ), tester_factory , use_random_test_inputs = False )
42+
43+ def test_index_put_accumulate (self , tester_factory : Callable ) -> None :
44+ # Test with accumulate=True and accumulate=False
45+
46+ # Without accumulation (replace values)
47+ indices = (torch .tensor ([0 , 2 ]),)
48+ values = torch .tensor ([10.0 , 20.0 ])
49+ self ._test_op (IndexPutModel (accumulate = False ),
50+ (torch .ones (5 , 2 ), indices , values ), tester_factory , use_random_test_inputs = False )
51+
52+ # With accumulation (add values)
53+ indices = (torch .tensor ([0 , 2 ]),)
54+ values = torch .tensor ([10.0 , 20.0 ])
55+ self ._test_op (IndexPutModel (accumulate = True ),
56+ (torch .ones (5 , 2 ), indices , values ), tester_factory , use_random_test_inputs = False )
57+
58+ def test_index_put_shapes (self , tester_factory : Callable ) -> None :
59+ # Test with different tensor shapes
60+
61+ # 1D tensor
62+ indices = (torch .tensor ([0 , 2 ]),)
63+ values = torch .tensor ([10.0 , 20.0 ])
64+ self ._test_op (IndexPutModel (),
65+ (torch .randn (5 ), indices , values ), tester_factory , use_random_test_inputs = False )
66+
67+ # 2D tensor
68+ indices = (torch .tensor ([0 , 2 ]), torch .tensor ([1 , 1 ]))
69+ values = torch .tensor ([10.0 , 20.0 ])
70+ self ._test_op (IndexPutModel (),
71+ (torch .randn (5 , 2 ), indices , values ), tester_factory , use_random_test_inputs = False )
72+
73+ # 3D tensor
74+ indices = (torch .tensor ([0 , 2 ]), torch .tensor ([1 , 1 ]), torch .tensor ([0 , 1 ]))
75+ values = torch .tensor ([10.0 , 20.0 ])
76+ self ._test_op (IndexPutModel (),
77+ (torch .randn (5 , 3 , 2 ), indices , values ), tester_factory , use_random_test_inputs = False )
78+
79+ # 4D tensor
80+ indices = (torch .tensor ([0 , 2 ]), torch .tensor ([1 , 1 ]),
81+ torch .tensor ([0 , 1 ]), torch .tensor ([2 , 3 ]))
82+ values = torch .tensor ([10.0 ,])
83+ self ._test_op (IndexPutModel (),
84+ (torch .randn (5 , 3 , 2 , 4 ), indices , values ), tester_factory , use_random_test_inputs = False )
85+
86+ def test_index_put_indices (self , tester_factory : Callable ) -> None :
87+ # Test with different index patterns
88+
89+ # Single index
90+ indices = (torch .tensor ([2 ]),)
91+ values = torch .tensor ([10.0 ])
92+ self ._test_op (IndexPutModel (),
93+ (torch .randn (5 , 2 ), indices , values ), tester_factory , use_random_test_inputs = False )
94+
95+ # Multiple indices
96+ indices = (torch .tensor ([0 , 2 , 4 ]),)
97+ values = torch .tensor ([10.0 , 20.0 , 30.0 ])
98+ self ._test_op (IndexPutModel (),
99+ (torch .randn (5 , 3 ), indices , values ), tester_factory , use_random_test_inputs = False )
100+
101+ # Repeated indices with accumulate=True (values add up)
102+ indices = (torch .tensor ([1 , 1 , 3 , 3 ]),)
103+ values = torch .tensor ([10.0 , 20.0 , 30.0 , 40.0 ])
104+ self ._test_op (IndexPutModel (accumulate = True ),
105+ (torch .randn (5 ), indices , values ), tester_factory , use_random_test_inputs = False )
106+
107+ def test_index_put_edge_cases (self , tester_factory : Callable ) -> None :
108+ # Test edge cases
109+
110+ # Put values in all positions
111+ indices = (torch .tensor ([0 , 1 , 2 , 3 , 4 ]),)
112+ values = torch .tensor ([10.0 , 20.0 , 30.0 , 40.0 , 50.0 ])
113+ self ._test_op (IndexPutModel (),
114+ (torch .randn (5 , 5 ), indices , values ), tester_factory , use_random_test_inputs = False )
115+
0 commit comments