
Commit 4f297a5

mcr229 authored and facebook-github-bot committed
move modules to ops (#257)
Summary:
Pull Request resolved: #257

Decrease the supported-module surface. For ops that are canonical (not decomposed) and form single-op partitions, we can always partition using the operator itself rather than the module: it is difficult to maintain every module variant of an op, and if the op is used inside a larger module such as multihead attention, it is impossible to partition it via source_fn. In the future, we want to support only modules that require recomposition, or some module-level recomposition.

Reviewed By: manuelcandales

Differential Revision: D49111303

fbshipit-source-id: eb46aa51cd129200fa3ecd45be34675ab0fa0a0e
1 parent 36d1138 commit 4f297a5
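As a rough illustration of the reasoning above, here is a minimal sketch (not the actual ExecuTorch partitioner) contrasting matching on the operator target with matching on source_fn metadata. SUPPORTED_OPS below stands in for the list in backends/xnnpack/partition/configs.py, and the "source_fn" meta key with its (name, module) layout is an assumption based on the torch.fx export convention of this period.

import torch
from torch.fx import Node

# Stand-in for the canonical, single-op targets kept in configs.py.
SUPPORTED_OPS = {
    torch.ops.aten.relu.default,
    torch.ops.aten.max_pool2d_with_indices.default,
}


def op_based_match(node: Node) -> bool:
    # Canonical single-op partitions: the call_function target alone
    # identifies the op, no matter which module emitted it.
    return node.op == "call_function" and node.target in SUPPORTED_OPS


def source_fn_match(node: Node, supported_modules: set) -> bool:
    # Module-based matching leans on source_fn metadata recorded during
    # tracing. If the op comes from inside a larger module (e.g. multihead
    # attention), source_fn points at that outer module and the match fails,
    # which is the limitation this commit sidesteps.
    source_fn = node.meta.get("source_fn")  # assumed meta key, see note above
    return source_fn is not None and source_fn[1] in supported_modules

Ops that reach the edge dialect as a single canonical node can therefore always be claimed by the op-based check, which is why entries such as relu, sigmoid, and cat move from the module lists into the op lists in the diff below.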

File tree: 2 files changed, +26 −38 lines


backends/xnnpack/partition/configs.py

Lines changed: 26 additions & 36 deletions
@@ -24,6 +24,7 @@
     exir_ops.edge.aten.upsample_bilinear2d.default,
     exir_ops.edge.aten.mean.dim,
     exir_ops.edge.aten.max.dim,
+    exir_ops.edge.aten.max_pool2d_with_indices.default,
     exir_ops.edge.aten.hardtanh.default,
     exir_ops.edge.aten.sqrt.default,
     exir_ops.edge.aten.ceil.default,
@@ -33,40 +34,49 @@
     exir_ops.edge.aten.abs.default,
     exir_ops.edge.aten._prelu_kernel.default,
     exir_ops.edge.aten.slice_copy.Tensor,
+    exir_ops.edge.aten.relu.default,
+    exir_ops.edge.aten.hardtanh.default,
+    exir_ops.edge.aten.permute_copy.default,
+    exir_ops.edge.aten.sigmoid.default,
+    exir_ops.edge.aten._softmax.default,
+    exir_ops.edge.aten.cat.default,
+    exir_ops.edge.aten.elu.default,
+    exir_ops.edge.aten.avg_pool2d.default,
+    exir_ops.edge.aten.leaky_relu.default,
 ]
 
 SUPPORTED_MODULES = [
     torch.nn.Conv1d,
     # TODO(T161981984) recomposed hardswish into a single node
-    torch.nn.Hardswish,
-    torch.nn.Hardsigmoid,
-    torch.nn.Conv2d,
-    torch.nn.ReLU,
-    torch.nn.Sigmoid,
-    torch.nn.Softmax,
-    torch.nn.BatchNorm1d,
+    torch.nn.Hardswish,  # we need to recompose
+    torch.nn.Hardsigmoid,  # we can handle decomposition
     torch.nn.BatchNorm2d,
+    torch.nn.BatchNorm1d,
+    torch.nn.Conv2d,
     torch.nn.Linear,
     torch.nn.functional.linear,
-    torch.nn.Hardtanh,
-    torch.nn.MaxPool2d,
-    torch.nn.LeakyReLU,
-    torch.nn.ELU,
-    torch.nn.AvgPool2d,
     torch.nn.PReLU,  # Without this, the PReLU weight becomes not a get_attr
-    torch.cat,
-    torch.concat,
-    torch.concatenate,
 ]
 
 # TODO delete this and should use SUPPORTED_OPS instead once we align fp32 and quant support
 SUPPORTED_QUANT_OPS = [
     exir_ops.edge.aten.add.Tensor,
+    exir_ops.edge.aten.clamp.default,
+    exir_ops.edge.aten.relu.default,
     exir_ops.edge.aten.sub.Tensor,
     exir_ops.edge.aten.mul.Tensor,
     exir_ops.edge.aten.mean.dim,
-    exir_ops.edge.aten.hardtanh.default,  # TODO - which one module or op or both?
+    exir_ops.edge.aten.hardtanh.default,
     exir_ops.edge.aten.slice_copy.Tensor,
+    exir_ops.edge.aten.permute_copy.default,
+    exir_ops.edge.aten.hardtanh.default,
+    exir_ops.edge.aten.mean.dim,
+    exir_ops.edge.aten.cat.default,
+    exir_ops.edge.aten.max_pool2d_with_indices.default,
+    exir_ops.edge.aten.max_pool2d.default,
+    exir_ops.edge.aten.constant_pad_nd.default,
+    exir_ops.edge.aten.elu.default,
+    exir_ops.edge.aten.leaky_relu.default,
 ]
 
 SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET = {
@@ -75,7 +85,6 @@
     SUPPORTED_QUANT_OPS
     + [
         exir_ops.edge.aten._to_copy.default,
-        exir_ops.edge.aten.max_pool2d.default,
         exir_ops.edge.aten.linear.default,
     ]
 )
@@ -88,37 +97,18 @@
 
 # TODO delete this and should use SUPPORTED_MODULES instead once we align fp32 and quant support
 SUPPORTED_QUANT_MODULES = [
-    torch.clamp,
-    torch.mean,
-    torch.permute,
-    torch.permute_copy,
-    torch.cat,
-    torch.concat,
-    torch.concatenate,
     torch.nn.Linear,
     torch.nn.functional.linear,
     # TODO - T158982884
     # torch.ao.nn.quantized.reference.modules.linear.Linear,
-    torch.nn.MaxPool2d,
     torch.nn.Conv1d,
     torch.nn.functional.conv1d,
     torch.ao.nn.quantized.reference.modules.conv.Conv1d,
     torch.nn.Conv2d,
     torch.nn.functional.conv2d,
-    torch.nn.functional.pad,
-    torch.nn.functional.elu,
     torch.ao.nn.quantized.reference.modules.conv.Conv2d,
     torch.nn.BatchNorm1d,
     torch.nn.BatchNorm2d,
-    torch.nn.ConstantPad2d,
-    torch.nn.ELU,
-    torch.nn.Hardtanh,
-    torch.nn.ReLU,
-    torch.nn.functional.relu,
-    torch.nn.functional.relu_,
-    torch.nn.functional.leaky_relu,
-    torch.nn.functional.leaky_relu_,
-    torch.nn.LeakyReLU,
 ]
 
 SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET = set(SUPPORTED_QUANT_MODULES)

backends/xnnpack/test/test_xnnpack_quantized.py

Lines changed: 0 additions & 2 deletions
@@ -199,8 +199,6 @@ def forward(self, x):
 
         self.quantize_and_test_model(LeakyReLUModule(), example_inputs)
 
-    # TODO(T158652796)
-    @unittest.expectedFailure
     def test_xnnpack_leaky_relu2(self):
         example_inputs = (torch.randn(1, 3, 3),)
 
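For context, here is a minimal sketch of the kind of eager module this test exercises. The actual LeakyReLUModule body is not shown in the diff, so the definition below is an illustrative assumption, not the repository's code.

import torch


class LeakyReLUModule(torch.nn.Module):
    # Hypothetical module shape; the real test class may configure
    # negative_slope or inplace differently.
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.LeakyReLU()

    def forward(self, x):
        return self.relu(x)


example_inputs = (torch.randn(1, 3, 3),)
# In the test, the module is run through the quantize-and-lower helper:
#     self.quantize_and_test_model(LeakyReLUModule(), example_inputs)

With the expectedFailure decorator removed, test_xnnpack_leaky_relu2 is now expected to pass, consistent with leaky_relu being handled as an op rather than a module.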