@@ -310,6 +310,33 @@ def test_qnn_backend_element_wise_mul(self):
310310 self .lower_module_and_test_output (module , sample_input )
311311 index += 1
312312
313+ def test_qnn_backend_element_wise_or (self ):
314+ test_comb = [
315+ {
316+ QCOM_MODULE : OrBitWise ( # noqa: F405
317+ torch .tensor (1.7 ), torch .tensor (0.2 )
318+ ),
319+ QCOM_SAMPLE_INPUTS : (
320+ torch .tensor ([1 , 0 , 1 , 0 ], dtype = torch .bool ),
321+ torch .tensor ([1 , 1 , 0 , 0 ], dtype = torch .bool ),
322+ ),
323+ },
324+ {
325+ QCOM_MODULE : OrOperator ( # noqa: F405
326+ torch .tensor (1.5 ), torch .tensor (- 1.2 )
327+ ),
328+ QCOM_SAMPLE_INPUTS : (
329+ torch .full ((3 , 3 ), 1 ).triu (),
330+ torch .full ((3 , 3 ), 1 ).tril (diagonal = 0 ),
331+ ),
332+ },
333+ ]
334+ for i , test in enumerate (test_comb ):
335+ with self .subTest (i = i ):
336+ self .lower_module_and_test_output (
337+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
338+ )
339+
313340 def test_qnn_backend_element_wise_sqrt (self ):
314341 modules = [Sqrt (), SqrtConstant ()] # noqa: F405
315342 for i , module in enumerate (modules ):
@@ -1246,6 +1273,34 @@ def test_qnn_backend_element_wise_mul(self):
12461273 self .lower_module_and_test_output (module , sample_input )
12471274 index += 1
12481275
1276+ def test_qnn_backend_element_wise_or (self ):
1277+ test_comb = [
1278+ {
1279+ QCOM_MODULE : OrBitWise ( # noqa: F405
1280+ torch .tensor (1.7 ), torch .tensor (0.2 )
1281+ ),
1282+ QCOM_SAMPLE_INPUTS : (
1283+ torch .tensor ([1 , 0 , 1 , 0 ], dtype = torch .bool ),
1284+ torch .tensor ([1 , 1 , 0 , 0 ], dtype = torch .bool ),
1285+ ),
1286+ },
1287+ {
1288+ QCOM_MODULE : OrOperator ( # noqa: F405
1289+ torch .tensor (1.5 ), torch .tensor (- 1.2 )
1290+ ),
1291+ QCOM_SAMPLE_INPUTS : (
1292+ torch .full ((3 , 3 ), 1 ).triu (),
1293+ torch .full ((3 , 3 ), 1 ).tril (diagonal = 0 ),
1294+ ),
1295+ },
1296+ ]
1297+ for i , test in enumerate (test_comb ):
1298+ with self .subTest (i = i ):
1299+ module = self .get_qdq_module (
1300+ test [QCOM_MODULE ], test [QCOM_SAMPLE_INPUTS ]
1301+ )
1302+ self .lower_module_and_test_output (module , test [QCOM_SAMPLE_INPUTS ])
1303+
12491304 def test_qnn_backend_element_wise_sqrt (self ):
12501305 modules = [Sqrt (), SqrtConstant ()] # noqa: F405
12511306 for i , module in enumerate (modules ):
0 commit comments