@@ -681,6 +681,10 @@ def test_qnn_backend_view(self):
         sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256]))
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_argmin(self):
+        module = Conv2dArgmin()  # noqa: F405
+        sample_input = (torch.randn(16, 3, 16, 16),)
+        self.lower_module_and_test_output(module, sample_input)
 
 class TestQNNFloatingPointModel(TestQNN):
     # TODO: refactor to support different backends
@@ -704,7 +708,8 @@ def test_qnn_backend_chunk_add(self):
         torch.manual_seed(8)
         sample_input = (torch.randn(1, 2, 4, 2),)
         self.lower_module_and_test_output(module, sample_input)
-
+
+
     def test_qnn_backend_conv1d_relu_log_softmax(self):
         module = Conv1dReluLogSoftmax()  # noqa: F405
         sample_input = (torch.rand(1, 2, 28),)
@@ -1585,6 +1590,12 @@ def test_qnn_backend_view(self):
         sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256]))
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
+
+    def test_qnn_backend_argmin(self):
+        module = Conv2dArgmin()  # noqa: F405
+        sample_input = (torch.randn(16, 3, 16, 16),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
 
 
 class TestQNNQuantizedModel(TestQNN):
@@ -1610,6 +1621,8 @@ def test_qnn_backend_chunk_add(self):
         sample_input = (torch.randn(1, 1, 4, 2),)
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
+
+
 
     def test_qnn_backend_conv1d_relu_log_softmax(self):
         module = Conv1dReluLogSoftmax()  # noqa: F405
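
Note on the new tests: both the floating-point and the quantized variants assume a Conv2dArgmin test model defined alongside the other QNN test modules; its definition is not part of this diff. A minimal sketch of what such a module could look like, under that assumption only (layer choices and parameters below are hypothetical, not the repository's actual definition), is:

import torch
import torch.nn as nn


class Conv2dArgmin(nn.Module):
    # Hypothetical sketch matching the name and the sample input
    # torch.randn(16, 3, 16, 16) used in the tests above: a single
    # Conv2d followed by argmin.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv(x)
        # Reduce over the channel dimension; keepdim keeps the output 4-D.
        return torch.argmin(x, dim=1, keepdim=True)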