@@ -513,6 +513,11 @@ def test_qnn_backend_log(self):
513513 sample_input = (torch .rand ([1 , 2 , 3 , 4 ]),)
514514 self .lower_module_and_test_output (module , sample_input )
515515
516+ def test_qnn_backend_logical_not (self ):
517+ module = LogicalNot () # noqa: F405
518+ sample_input = (torch .rand ([1 , 2 , 3 , 4 ]),)
519+ self .lower_module_and_test_output (module , sample_input )
520+
516521 def test_qnn_backend_log_softmax (self ):
517522 module = LogSoftmax () # noqa: F405
518523 sample_input = (torch .randn ([1 , 4 , 8 , 8 ]),)
@@ -692,6 +697,18 @@ def test_qnn_backend_view(self):
692697 sample_input = (torch .randn ([1 , 8 , 512 ]), torch .randn ([1 , 2 , 8 , 256 ]))
693698 self .lower_module_and_test_output (module , sample_input )
694699
700+ def test_qnn_backend_where (self ):
701+ modules = [
702+ Where (), # noqa: F405
703+ WhereConstant (torch .randn (3 , 2 ), torch .randn (3 , 2 )), # noqa: F405
704+ ]
705+ sample_inputs = [
706+ (torch .randn (3 , 2 ), torch .randn (3 , 2 ), torch .randn (3 , 2 )),
707+ (torch .randn (3 , 2 ),),
708+ ]
709+ for i , module in enumerate (modules ):
710+ self .lower_module_and_test_output (module , sample_inputs [i ])
711+
695712
696713class TestQNNFloatingPointModel (TestQNN ):
697714 # TODO: refactor to support different backends
@@ -1396,6 +1413,12 @@ def test_qnn_backend_log(self):
13961413 module = self .get_qdq_module (module , sample_input )
13971414 self .lower_module_and_test_output (module , sample_input )
13981415
1416+ def test_qnn_backend_logical_not (self ):
1417+ module = LogicalNot () # noqa: F405
1418+ sample_input = (torch .rand ([1 , 2 , 3 , 4 ]),)
1419+ module = self .get_qdq_module (module , sample_input )
1420+ self .lower_module_and_test_output (module , sample_input )
1421+
13991422 def test_qnn_backend_log_softmax (self ):
14001423 module = LogSoftmax () # noqa: F405
14011424 sample_input = (torch .randn ([1 , 4 , 8 , 8 ]),)
@@ -1609,6 +1632,19 @@ def test_qnn_backend_view(self):
16091632 module = self .get_qdq_module (module , sample_input )
16101633 self .lower_module_and_test_output (module , sample_input )
16111634
1635+ def test_qnn_backend_where (self ):
1636+ modules = [
1637+ Where (), # noqa: F405
1638+ WhereConstant (torch .randn (3 , 2 ), torch .randn (3 , 2 )), # noqa: F405
1639+ ]
1640+ sample_inputs = [
1641+ (torch .randn (3 , 2 ), torch .randn (3 , 2 ), torch .randn (3 , 2 )),
1642+ (torch .randn (3 , 2 ),),
1643+ ]
1644+ for i , module in enumerate (modules ):
1645+ module = self .get_qdq_module (module , sample_inputs [i ])
1646+ self .lower_module_and_test_output (module , sample_inputs [i ])
1647+
16121648
16131649class TestQNNQuantizedModel (TestQNN ):
16141650 # TODO: refactor to support different backends
0 commit comments