Skip to content

Commit 51ec1a5

Browse files
manuelcandales authored and facebook-github-bot committed
Update add tests (#259)
Summary: Pull Request resolved: #259 Reviewed By: mcr229 Differential Revision: D48921118 fbshipit-source-id: 638be62041a71c389d9124597ba4a085efa9b39d
1 parent 0affcf2 commit 51ec1a5

File tree

1 file changed

+43
-24
lines changed
  • backends/xnnpack/test/ops

1 file changed

+43
-24
lines changed

backends/xnnpack/test/ops/add.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from executorch.backends.xnnpack.test.tester import Partition, Tester
1414

1515

16-
class TestXNNPACKAdd(unittest.TestCase):
17-
class AddModule(torch.nn.Module):
16+
class TestAdd(unittest.TestCase):
17+
class Add(torch.nn.Module):
1818
def __init__(self):
1919
super().__init__()
2020

@@ -25,15 +25,10 @@ def forward(self, x, y):
2525
z = z + z
2626
return z
2727

28-
def test_add(self):
29-
"""
30-
This test is the simplest test by manually lowering some submodules, we can use paritioner for auto detecting lowerable parts
31-
"""
32-
add_module = self.AddModule()
33-
model_inputs = (torch.ones(1), torch.ones(1))
34-
28+
def test_fp32_add(self):
29+
inputs = (torch.ones(1), torch.ones(1))
3530
(
36-
Tester(add_module, model_inputs)
31+
Tester(self.Add(), inputs)
3732
.export()
3833
.check_count({"torch.ops.aten.add.Tensor": 4})
3934
.to_edge()
@@ -47,16 +42,14 @@ def test_add(self):
4742
.compare_outputs()
4843
)
4944

50-
def test_add_quantized(self):
51-
add_module = self.AddModule()
52-
model_inputs = (torch.ones(1), torch.ones(1))
53-
45+
def test_qs8_add(self):
46+
inputs = (torch.ones(1), torch.ones(1))
5447
(
55-
Tester(add_module, model_inputs)
56-
.quantize()
57-
.check(["torch.ops.quantized_decomposed"])
48+
Tester(self.Add(), inputs)
49+
.quantize2()
5850
.export()
5951
.check_count({"torch.ops.aten.add.Tensor": 4})
52+
.check(["torch.ops.quantized_decomposed"])
6053
.to_edge()
6154
.check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 4})
6255
.partition(Partition(partitioner=XnnpackQuantizedPartitioner))
@@ -69,22 +62,48 @@ def test_add_quantized(self):
6962
.compare_outputs()
7063
)
7164

72-
def test_add_quantized_pt2e(self):
73-
add_module = self.AddModule()
74-
model_inputs = (torch.ones(1), torch.ones(1))
65+
class AddRelu(torch.nn.Module):
66+
def forward(self, x, y):
67+
z = x + y
68+
return torch.nn.functional.relu(z)
7569

70+
def test_fp32_add_relu(self):
71+
inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
7672
(
77-
Tester(add_module, model_inputs)
73+
Tester(self.AddRelu(), inputs)
74+
.export()
75+
.check_count({"torch.ops.aten.add.Tensor": 1})
76+
.check_count({"torch.ops.aten.relu.default": 1})
77+
.to_edge()
78+
.check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1})
79+
.check_count({"executorch_exir_dialects_edge__ops_aten_relu_default": 1})
80+
.partition()
81+
.check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
82+
.check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"])
83+
.check_count({"torch.ops.executorch_call_delegate": 1})
84+
.to_executorch()
85+
.serialize()
86+
.run_method()
87+
.compare_outputs()
88+
)
89+
90+
def test_qs8_add_relu(self):
91+
inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
92+
(
93+
Tester(self.AddRelu(), inputs)
7894
.quantize2()
7995
.export()
80-
.check_count({"torch.ops.aten.add.Tensor": 4})
96+
.check_count({"torch.ops.aten.add.Tensor": 1})
97+
.check_count({"torch.ops.aten.relu.default": 1})
8198
.check(["torch.ops.quantized_decomposed"])
8299
.to_edge()
83-
.check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 4})
100+
.check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1})
101+
.check_count({"executorch_exir_dialects_edge__ops_aten_relu_default": 1})
84102
.partition(Partition(partitioner=XnnpackQuantizedPartitioner))
85-
.check_count({"torch.ops.executorch_call_delegate": 1})
86103
.check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"])
104+
.check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"])
87105
.check_not(["torch.ops.quantized_decomposed"])
106+
.check_count({"torch.ops.executorch_call_delegate": 1})
88107
.to_executorch()
89108
.serialize()
90109
.run_method()

0 commit comments

Comments (0)