@@ -478,6 +478,13 @@ def test_qnn_backend_full_like(self):
478478 sample_input = (torch .randn (1 , 2 , 3 , 4 ),)
479479 self .lower_module_and_test_output (module , sample_input )
480480
481+ def test_qnn_backend_gather (self ):
482+ module = Gather () # noqa: F405
483+ shape = (2 , 2 , 3 , 4 )
484+ sample_input = (torch .randn (shape ), torch .randn (shape ))
485+ module = self .get_qdq_module (module , sample_input )
486+ self .lower_module_and_test_output (module , sample_input )
487+
481488 def test_qnn_backend_gelu (self ):
482489 module = Gelu () # noqa: F405
483490 sample_input = (torch .randn (2 , 5 , 1 , 3 ),)
@@ -821,12 +828,17 @@ def test_qnn_backend_select_copy(self):
821828 self .lower_module_and_test_output (module , sample_input )
822829
823830 def test_qnn_backend_slice_copy (self ):
824- modules = [SliceCopy (), SliceCopyWithStep ()] # noqa: F405
825- sample_input = (
826- torch .randn ([1 , 512 ]),
827- torch .randn ([1 , 8 ]),
828- )
829- for module in modules :
831+ modules = [
832+ SliceCopyDefaultParameter (),
833+ SliceCopy (),
834+ SliceCopyWithStep (),
835+ ] # noqa: F405
836+ sample_inputs = [
837+ (torch .randn ([2 , 1 , 320 , 512 ]),),
838+ (torch .randn ([1 , 512 ]), torch .randn ([1 , 8 ])),
839+ (torch .randn ([1 , 512 ]), torch .randn ([1 , 8 ])),
840+ ]
841+ for module , sample_input in zip (modules , sample_inputs ):
830842 self .lower_module_and_test_output (module , sample_input )
831843
832844 def test_qnn_backend_stack (self ):
@@ -1593,6 +1605,13 @@ def test_qnn_backend_full_like(self):
15931605 module = self .get_qdq_module (module , sample_input )
15941606 self .lower_module_and_test_output (module , sample_input )
15951607
1608+ def test_qnn_backend_gather (self ):
1609+ module = Gather () # noqa: F405
1610+ shape = (2 , 2 , 3 , 4 )
1611+ sample_input = (torch .randn (shape ), torch .randn (shape ))
1612+ module = self .get_qdq_module (module , sample_input )
1613+ self .lower_module_and_test_output (module , sample_input )
1614+
15961615 def test_qnn_backend_gelu (self ):
15971616 module = Gelu () # noqa: F405
15981617 sample_input = (torch .randn (2 , 5 , 1 , 3 ),)
@@ -1991,12 +2010,17 @@ def test_qnn_backend_sin(self):
19912010 self .lower_module_and_test_output (module , sample_input )
19922011
19932012 def test_qnn_backend_slice_copy (self ):
1994- modules = [SliceCopy (), SliceCopyWithStep ()] # noqa: F405
1995- sample_input = (
1996- torch .randn ([1 , 512 ]),
1997- torch .randn ([1 , 8 ]),
1998- )
1999- for module in modules :
2013+ modules = [
2014+ SliceCopyDefaultParameter (),
2015+ SliceCopy (),
2016+ SliceCopyWithStep (),
2017+ ] # noqa: F405
2018+ sample_inputs = [
2019+ (torch .randn ([2 , 1 , 320 , 512 ]),),
2020+ (torch .randn ([1 , 512 ]), torch .randn ([1 , 8 ])),
2021+ (torch .randn ([1 , 512 ]), torch .randn ([1 , 8 ])),
2022+ ]
2023+ for module , sample_input in zip (modules , sample_inputs ):
20002024 module = self .get_qdq_module (module , sample_input )
20012025 self .lower_module_and_test_output (module , sample_input )
20022026
0 commit comments