@@ -555,6 +555,11 @@ def test_qnn_backend_expm1(self):
555555 module = ExpM1 () # noqa: F405
556556 self .lower_module_and_test_output (module , sample_input )
557557
558+ def test_qnn_backend_flip (self ):
559+ sample_input = (torch .randn (3 , 4 , 5 , 6 ),)
560+ module = Flip () # noqa: F405
561+ self .lower_module_and_test_output (module , sample_input )
562+
558563 def test_qnn_backend_floor (self ):
559564 sample_input = (torch .randn (3 , 4 ),)
560565 module = Floor () # noqa: F405
@@ -778,6 +783,14 @@ def test_qnn_backend_index_put(self):
778783 skip_mutable_buffer = test [QCOM_MODULE ].skip_mutable_buffer ,
779784 )
780785
786+ def test_qnn_backend_index_select (self ):
787+ module = IndexSelect (dim = 1 ) # noqa: F405
788+ sample_input = (
789+ torch .randn (2 , 3 , 4 , 5 ),
790+ torch .tensor ([0 , 2 ]),
791+ )
792+ self .lower_module_and_test_output (module , sample_input )
793+
781794 def test_qnn_backend_instance_norm_2d (self ):
782795 modules = [InstanceNorm2d (32 ), InstanceNorm2d (32 , affine = False )] # noqa: F405
783796 sample_input = (torch .randn ([4 , 32 , 16 , 16 ]),)
@@ -2031,17 +2044,11 @@ def test_qnn_backend_expm1(self):
20312044 self .lower_module_and_test_output (module , sample_input )
20322045
20332046 def test_qnn_backend_flip (self ):
2034- sample_input = (torch .randn (3 , 4 , 5 ,6 ),)
2035- # golden_module = Flip()
2036- decomp_module = FlipDecomp ()
2037- decomp_module = self .get_qdq_module (decomp_module , sample_input )
2038- self .lower_module_and_test_output (decomp_module , sample_input )
2039- # golden_out = golden_module(sample_input)
2040- # decomp_out = decomp_module(sample_input)
2041- # torch.testing.assert_close(golden_out, decomp_out)
2042-
2043-
2044-
2047+ sample_input = (torch .randn (3 , 4 , 5 , 6 ),)
2048+ module = Flip () # noqa: F405
2049+ module = self .get_qdq_module (module , sample_input )
2050+ self .lower_module_and_test_output (module , sample_input )
2051+
20452052 def test_qnn_backend_floor (self ):
20462053 sample_input = (torch .randn (3 , 4 ),)
20472054 module = Floor () # noqa: F405
@@ -2285,6 +2292,15 @@ def test_qnn_backend_index_put(self):
22852292 skip_mutable_buffer = test [QCOM_MODULE ].skip_mutable_buffer ,
22862293 )
22872294
2295+ def test_qnn_backend_index_select (self ):
2296+ module = IndexSelect (dim = 1 ) # noqa: F405
2297+ sample_input = (
2298+ torch .randn (2 , 3 , 4 , 5 ),
2299+ torch .tensor ([0 , 2 ]),
2300+ )
2301+ module = self .get_qdq_module (module , sample_input )
2302+ self .lower_module_and_test_output (module , sample_input )
2303+
22882304 def test_qnn_backend_instance_norm_2d (self ):
22892305 modules = [InstanceNorm2d (32 ), InstanceNorm2d (32 , affine = False )] # noqa: F405
22902306 sample_input = (torch .randn ([4 , 32 , 16 , 16 ]),)
0 commit comments