
Commit 26aac45

cccclai authored and facebook-github-bot committed
Add logical and op (pytorch#13342)
Summary: As title, add the logical_and op for an internal model.

Rollback Plan:

Differential Revision: D80122607
1 parent 43bd889 commit 26aac45

4 files changed: 26 additions & 1 deletion


backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions
@@ -92,6 +92,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.le.Tensor,
         exir_ops.edge.aten.linear.default,
         exir_ops.edge.aten.log.default,
+        exir_ops.edge.aten.logical_and.default,
         exir_ops.edge.aten.logical_not.default,
         exir_ops.edge.aten.lt.Scalar,
         exir_ops.edge.aten.lt.Tensor,
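For context, torch.logical_and lowers to the edge-dialect target exir_ops.edge.aten.logical_and.default, which is the op registered as layout agnostic above. Below is a minimal sketch of how one might inspect that, assuming the standard ExecuTorch export flow (torch.export.export plus executorch.exir.to_edge); this snippet is not part of the commit and the module name is illustrative.

```python
import torch
from executorch.exir import to_edge


class LogicalAnd(torch.nn.Module):
    def forward(self, x, y):
        return torch.logical_and(x, y)


example_inputs = (torch.tensor([True, False]), torch.tensor([True, True]))
edge = to_edge(torch.export.export(LogicalAnd(), example_inputs))

# The edge-dialect graph should contain exir_ops.edge.aten.logical_and.default,
# the op that LayoutTransform now treats as layout agnostic.
print(edge.exported_program().graph)
```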

backends/qualcomm/quantizer/annotators.py

Lines changed: 1 addition & 1 deletion
@@ -828,7 +828,7 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> None:
     )
 
 
-@register_annotator([torch.ops.aten.__and__.Tensor])
+@register_annotator([torch.ops.aten.__and__.Tensor, torch.ops.aten.logical_and.default])
 def annotate_and(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)
 
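The annotator that previously matched only the __and__ operator now also matches the logical_and overload. A quick sketch (not from the commit) showing that the newly registered target, torch.ops.aten.logical_and.default, computes the same elementwise result as the public torch.logical_and API:

```python
import torch

a = torch.tensor([True, False, True, False])
b = torch.tensor([True, True, False, False])

# torch.ops.aten.logical_and.default is the ATen overload added to the
# annotator's target list; it mirrors the public torch.logical_and API.
assert torch.equal(torch.ops.aten.logical_and.default(a, b), torch.logical_and(a, b))
```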

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
@@ -1171,6 +1171,14 @@ def forward(self, x):
         return torch.log(x)
 
 
+class LogicalAnd(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return torch.logical_and(x, y)
+
+
 class LogicalNot(torch.nn.Module):
     def __init__(self):
         super().__init__()
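The new test module can be exercised eagerly, before any QNN lowering. A minimal sketch, assuming the repository is importable as the executorch package (the import path below is an assumption, not shown in the commit):

```python
import torch
from executorch.backends.qualcomm.tests.models import LogicalAnd  # assumed import path

module = LogicalAnd()
x = torch.tensor([True, False, True, False])
y = torch.tensor([True, True, False, False])

# Eager run: elementwise AND of the two boolean tensors.
print(module(x, y))  # tensor([ True, False, False, False])
```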

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 16 additions & 0 deletions
@@ -910,6 +910,14 @@ def test_qnn_backend_log(self):
         sample_input = (torch.rand([1, 2, 3, 4]),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_logical_and(self):
+        module = LogicalAnd()  # noqa: F405
+        input1 = torch.tensor([True, False, True, False])
+        input2 = torch.tensor([True, True, False, False])
+        sample_input = (input1, input2)
+        self.lower_module_and_test_output(module, sample_input)
+
+
     def test_qnn_backend_logical_not(self):
         module = LogicalNot()  # noqa: F405
         sample_input = (torch.rand([1, 2, 3, 4]),)
@@ -2443,6 +2451,14 @@ def test_qnn_backend_log(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_logical_and(self):
+        module = LogicalAnd()  # noqa: F405
+        input1 = torch.tensor([True, False, True, False])
+        input2 = torch.tensor([True, True, False, False])
+        sample_input = (input1, input2)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_logical_not(self):
         module = LogicalNot()  # noqa: F405
         sample_input = (torch.rand([1, 2, 3, 4]),)
