
Commit 0f6bae0

mcr229 authored and facebook-github-bot committed
MobileNetv3 FP32 + QS8 (#142)
Summary:
Pull Request resolved: #142

Tester for MobileNetv3

Reviewed By: digantdesai, guangy10, kirklandsign

Differential Revision: D48667678

fbshipit-source-id: f5b8f94320ca6ddeb063a98e51dcb8bc05c3c4d6
1 parent 41125f9 commit 0f6bae0

File tree

1 file changed: +82 −0 lines changed

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torchvision.models as models
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackQuantizedPartitioner2,
)
from executorch.backends.xnnpack.test.tester import Partition, Tester
from executorch.backends.xnnpack.test.tester.tester import Export
from executorch.backends.xnnpack.utils.configs import get_xnnpack_capture_config


class TestMobileNetV3(unittest.TestCase):
    export_stage = Export(get_xnnpack_capture_config(enable_aot=True))

    mv3 = models.mobilenetv3.mobilenet_v3_small(pretrained=True)
    mv3 = mv3.eval()
    model_inputs = (torch.ones(1, 3, 224, 224),)
    all_operators = {
        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
        "executorch_exir_dialects_edge__ops_aten_clamp_default",
        "executorch_exir_dialects_edge__ops_aten_permute_copy_default",
        "executorch_exir_dialects_edge__ops_aten_addmm_default",
        "executorch_exir_dialects_edge__ops_aten__to_copy_default",
        "executorch_exir_dialects_edge__ops_aten_convolution_default",
        "executorch_exir_dialects_edge__ops_aten_relu_default",
        "executorch_exir_dialects_edge__ops_aten_add_Tensor",
        "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
        "executorch_exir_dialects_edge__ops_aten_div_Tensor",
        "executorch_exir_dialects_edge__ops_aten_mean_dim",
    }

    def test_fp32(self):
        (
            Tester(self.mv3, self.model_inputs)
            .export(self.export_stage)
            .to_edge()
            .check(list(self.all_operators))
            .partition()
            .check(["torch.ops.executorch_call_delegate"])
            .check_not(list(self.all_operators))
            .to_executorch()
            .serialize()
            .run_method()
            .compare_outputs()
        )
    def test_qs8_pt2e(self):
        ops_after_quantization = self.all_operators - {
            "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
        }
        ops_after_lowering = self.all_operators - {
            # TODO: move to a unified partitioner; the decomposed hardswish/hardsigmoid
            # operators are not quantized, so the quantized partitioner will not partition them.
            "executorch_exir_dialects_edge__ops_aten_add_Tensor",
            "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
            "executorch_exir_dialects_edge__ops_aten_div_Tensor",
            "executorch_exir_dialects_edge__ops_aten_clamp_default",
            "executorch_exir_dialects_edge__ops_aten__to_copy_default",
        }

        (
            Tester(self.mv3, self.model_inputs)
            .quantize2()
            .export(self.export_stage)
            .to_edge()
            .check(list(ops_after_quantization))
            .partition(Partition(partitioner=XnnpackQuantizedPartitioner2))
            .check(["torch.ops.executorch_call_delegate"])
            .check_not(list(ops_after_lowering))
            .to_executorch()
            .serialize()
            .run_method()
            .compare_outputs()
        )
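
To exercise these tests locally, a minimal sketch using the standard unittest loader is shown below. The module path executorch.backends.xnnpack.test.models.test_mobilenet_v3 is an assumption, since the commit page does not show the file's name or location.

import unittest

# Hypothetical module path for the file added in this commit; adjust it to the
# actual location in the ExecuTorch tree before running.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "executorch.backends.xnnpack.test.models.test_mobilenet_v3.TestMobileNetV3"
)
unittest.TextTestRunner(verbosity=2).run(suite)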
