@@ -517,6 +517,11 @@ def test_qnn_backend_log(self):
517517 sample_input = (torch .rand ([1 , 2 , 3 , 4 ]),)
518518 self .lower_module_and_test_output (module , sample_input )
519519
520+ def test_qnn_backend_logical_not (self ):
521+ module = LogicalNot () # noqa: F405
522+ sample_input = (torch .rand ([1 , 2 , 3 , 4 ]),)
523+ self .lower_module_and_test_output (module , sample_input )
524+
520525 def test_qnn_backend_log_softmax (self ):
521526 module = LogSoftmax () # noqa: F405
522527 sample_input = (torch .randn ([1 , 4 , 8 , 8 ]),)
@@ -696,6 +701,18 @@ def test_qnn_backend_view(self):
696701 sample_input = (torch .randn ([1 , 8 , 512 ]), torch .randn ([1 , 2 , 8 , 256 ]))
697702 self .lower_module_and_test_output (module , sample_input )
698703
704+ def test_qnn_backend_where (self ):
705+ modules = [
706+ Where (), # noqa: F405
707+ WhereConstant (torch .randn (3 , 2 ), torch .randn (3 , 2 )), # noqa: F405
708+ ]
709+ sample_inputs = [
710+ (torch .randn (3 , 2 ), torch .randn (3 , 2 ), torch .randn (3 , 2 )),
711+ (torch .randn (3 , 2 ),),
712+ ]
713+ for i , module in enumerate (modules ):
714+ self .lower_module_and_test_output (module , sample_inputs [i ])
715+
699716
700717class TestQNNFloatingPointModel (TestQNN ):
701718 # TODO: refactor to support different backends
@@ -1400,6 +1417,12 @@ def test_qnn_backend_log(self):
14001417 module = self .get_qdq_module (module , sample_input )
14011418 self .lower_module_and_test_output (module , sample_input )
14021419
1420+ def test_qnn_backend_logical_not (self ):
1421+ module = LogicalNot () # noqa: F405
1422+ sample_input = (torch .rand ([1 , 2 , 3 , 4 ]),)
1423+ module = self .get_qdq_module (module , sample_input )
1424+ self .lower_module_and_test_output (module , sample_input )
1425+
14031426 def test_qnn_backend_log_softmax (self ):
14041427 module = LogSoftmax () # noqa: F405
14051428 sample_input = (torch .randn ([1 , 4 , 8 , 8 ]),)
@@ -1613,6 +1636,19 @@ def test_qnn_backend_view(self):
16131636 module = self .get_qdq_module (module , sample_input )
16141637 self .lower_module_and_test_output (module , sample_input )
16151638
1639+ def test_qnn_backend_where (self ):
1640+ modules = [
1641+ Where (), # noqa: F405
1642+ WhereConstant (torch .randn (3 , 2 ), torch .randn (3 , 2 )), # noqa: F405
1643+ ]
1644+ sample_inputs = [
1645+ (torch .randn (3 , 2 ), torch .randn (3 , 2 ), torch .randn (3 , 2 )),
1646+ (torch .randn (3 , 2 ),),
1647+ ]
1648+ for i , module in enumerate (modules ):
1649+ module = self .get_qdq_module (module , sample_inputs [i ])
1650+ self .lower_module_and_test_output (module , sample_inputs [i ])
1651+
16161652
16171653class TestQNNQuantizedModel (TestQNN ):
16181654 # TODO: refactor to support different backends
0 commit comments