Commit 41125f9

mcr229 authored and facebook-github-bot committed
add unsupported module list for partitioner (#139)
Summary:
Pull Request resolved: #139

For quantized MobileNetV3: right now, decomposed hardswish and hardsigmoid get partitioned by XNNPACKQuantizedPartitioner2, because they contain simple ops like add, mul, div, etc. These ops are all in FP32, since XNNPACK doesn't support their quantized variants. Since we don't currently have support for a unified partition, we want to block them from our quantized partitioner.

Reviewed By: digantdesai

Differential Revision: D48667680

fbshipit-source-id: 20f7a82aa2f3f383b36b908951b32fbc0184fdb4
1 parent 76b1385 commit 41125f9
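
The fix leans on the source_fn entry that torch.fx records in each node's meta dict: when a module such as torch.nn.Hardswish is decomposed, the resulting add/mul/div nodes still carry the originating module class as the second element of meta["source_fn"]. A minimal sketch of the check under that assumption (the plain dicts below stand in for real torch.fx.Node.meta):

import torch

UNSUPPORTED_QUANT_MODULES = [torch.nn.Hardswish, torch.nn.Hardsigmoid]

def check_common_constraints(meta: dict) -> bool:
    # meta["source_fn"] is assumed to be (name, originating_module_or_fn),
    # matching the indexing used in the diff below.
    if "source_fn" in meta:
        return meta["source_fn"][1] not in UNSUPPORTED_QUANT_MODULES
    return True

# A mul node that came from a decomposed Hardswish is rejected...
print(check_common_constraints({"source_fn": ("hardswish_0", torch.nn.Hardswish)}))  # False
# ...while an ordinary node with no blocklisted source is kept.
print(check_common_constraints({}))  # True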

File tree

2 files changed: +25, -4


backends/xnnpack/partition/configs.py

Lines changed: 5 additions & 0 deletions
@@ -81,6 +81,11 @@
     )
 }
 
+UNSUPPORTED_QUANT_MODULES = [
+    torch.nn.Hardswish,
+    torch.nn.Hardsigmoid,
+]
+
 # TODO delete this and should use SUPPORTED_MODULES instead once we align fp32 and quant support
 SUPPORTED_QUANT_MODULES = [
     torch.clamp,
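
Why these two modules in particular: their decompositions bottom out in plain arithmetic. In PyTorch, hardswish(x) = x * relu6(x + 3) / 6 and hardsigmoid(x) = relu6(x + 3) / 6, so the decomposed graphs consist of add, mul, and div nodes, exactly the FP32 ops the quantized partitioner would otherwise claim. A quick numeric check of the identities:

import torch
import torch.nn.functional as F

x = torch.randn(16)

# hardswish decomposes to x * relu6(x + 3) / 6
assert torch.allclose(F.hardswish(x), x * F.relu6(x + 3.0) / 6.0)

# hardsigmoid is the same expression without the outer multiply
assert torch.allclose(F.hardsigmoid(x), F.relu6(x + 3.0) / 6.0)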

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 20 additions & 4 deletions
@@ -17,6 +17,7 @@
     SUPPORTED_OPS,
     SUPPORTED_QUANT_MODULES,
     SUPPORTED_QUANT_OPS,
+    UNSUPPORTED_QUANT_MODULES,
 )
 from executorch.backends.xnnpack.partition.support_patterns import (
     get_add_graphs,
@@ -103,17 +104,25 @@ def __init__(
             Any, Callable[[torch.fx.Node], bool]
         ] = _OP_SUPPORT_CONSTRAINTS,
         supported_ops: Optional[List] = None,
+        unsupported_modules: Optional[List] = None,
     ):
         """
         @Arg constraints_dict: Dict mapping each node to a lambda function that
             returns True if backend constraints are met for that instance of the
             node.
         @Arg supported_ops: List of supported operators for partitioning
         """
+        self.unsupported_modules = unsupported_modules
         self.supported_ops = supported_ops
         self.constraints = constraints_dict
         assert len(self.constraints)
 
+    def check_common_constraints(self, node) -> bool:
+        if self.unsupported_modules and "source_fn" in node.meta:
+            return not node.meta["source_fn"][1] in self.unsupported_modules
+
+        return True
+
     @staticmethod
     def check_constraint(node) -> bool:
         """
@@ -132,7 +141,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         if self.supported_ops and node.target not in self.supported_ops:
             return False
 
-        return self.check_constraint(node)
+        return self.check_constraint(node) and self.check_common_constraints(node)
 
 def _constraint(target):  # noqa
     """
@@ -540,10 +549,11 @@ def __init__(
         self,
         supported_modules: List[Callable] = SUPPORTED_MODULES,
         supported_ops: Optional[List[Callable]] = SUPPORTED_OPS,
+        unsupported_modules: Optional[List[Callable]] = None,
     ):
         super().__init__()
         self.supported_modules = set(supported_modules)
-
+        self.unsupported_modules = unsupported_modules
         self.supported_ops = set(supported_ops or [])
 
         self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, [])
@@ -614,7 +624,10 @@ def generate_partitions(self, graph_module: torch.fx.GraphModule) -> List[Any]:
         return generate_partitions_from_list_of_nodes(
             graph_module,
             matched_module_nodes,
-            XnnpackOperatorSupport(supported_ops=self.supported_ops),
+            XnnpackOperatorSupport(
+                supported_ops=self.supported_ops,
+                unsupported_modules=self.unsupported_modules,
+            ),
         )
 
     def tag_nodes(self, partitions: List[Partition]) -> None:
@@ -668,9 +681,12 @@ def __init__(
         self,
         supported_modules=SUPPORTED_QUANT_MODULES,
         supported_ops=SUPPORTED_QUANT_OPS,
+        unsupported_modules=UNSUPPORTED_QUANT_MODULES,
     ):
         supported_ops = supported_ops or []
-        super().__init__(supported_modules, supported_ops + self._QUANT_OPS)
+        super().__init__(
+            supported_modules, supported_ops + self._QUANT_OPS, unsupported_modules
+        )
 
     # TODO Refactor this
     # TODO Don't be greedy when pulling q->dq pairs for a given op, add convert tracker pass
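
With the new default above, the quantized partitioner carries the blocklist automatically. A hypothetical usage sketch, assuming the class name given in the summary (XNNPACKQuantizedPartitioner2) and the module paths from the diff headers; the actual public API may differ:

import torch

from executorch.backends.xnnpack.partition.configs import UNSUPPORTED_QUANT_MODULES
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XNNPACKQuantizedPartitioner2,  # name as written in the summary; casing unverified
)

# Default construction picks up UNSUPPORTED_QUANT_MODULES, keeping decomposed
# Hardswish/Hardsigmoid nodes out of quantized partitions.
partitioner = XNNPACKQuantizedPartitioner2()

# The blocklist can also be extended explicitly, e.g. with Hardtanh:
partitioner = XNNPACKQuantizedPartitioner2(
    unsupported_modules=UNSUPPORTED_QUANT_MODULES + [torch.nn.Hardtanh],
)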
