Commit 925cd1b

manuelcandales authored and facebook-github-bot committed
Unified XnnpackPartitioner (#265)
Summary:
Pull Request resolved: #265

Reviewed By: mcr229

Differential Revision: D48761226

fbshipit-source-id: 484f22c3a4f498ef3024bf7bb5afbc06f7659e93
1 parent 71470a7 · commit 925cd1b
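This commit folds the previous XnnpackFloatingPointPartitioner and XnnpackQuantizedPartitioner into a single XnnpackPartitioner whose optional quant flag selects which pattern sets are matched. A minimal usage sketch (the quant keyword, its default, and partition() come from the diff below; edge_graph_module is a hypothetical placeholder for an edge-dialect torch.fx.GraphModule, not part of this diff):

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)

# quant=True  -> match only SUPPORTED_QUANT_MODULES / SUPPORTED_QUANT_OPS
# quant=False -> match only SUPPORTED_MODULES / SUPPORTED_OPS
# quant=None (the default) -> match the union of both sets
fp32_partitioner = XnnpackPartitioner(quant=False)
quant_partitioner = XnnpackPartitioner(quant=True)
unified_partitioner = XnnpackPartitioner()

# partition() tags delegable nodes via node.meta and returns the same module;
# edge_graph_module is an assumed input, defined elsewhere.
tagged = unified_partitioner.partition(edge_graph_module)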

File tree

2 files changed: +272 -5 lines changed

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 269 additions & 1 deletion
@@ -7,7 +7,7 @@
 import itertools
 import logging
 import operator
-from typing import Any, Callable, cast, Dict, List, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Set, Union

 import torch

@@ -528,6 +528,274 @@ def get_nodes(self, src_partition: SourcePartition) -> List[torch.fx.Node]: # n
         )


+class XnnpackPartitioner(Partitioner):
+    """
+    Module and Opname based partitioner for FP32 modules/ops listed in
+    SUPPORTED_MODULES and SUPPORTED_OPS and statically quantized modules/ops listed in
+    SUPPORTED_QUANT_MODULES and SUPPORTED_QUANT_OPS.
+    """
+
+    _Q_OPS = [
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
+    ]
+
+    _DQ_OPS = [
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
+    ]
+
+    _QPARAM_OPS = [
+        exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
+    ]
+
+    _QUANT_OPS = _Q_OPS + _DQ_OPS + _QPARAM_OPS
+
+    def __init__(
+        self,
+        *,
+        supported_modules: List[Callable] = SUPPORTED_MODULES,
+        supported_ops: Optional[List[Callable]] = SUPPORTED_OPS,
+        supported_quant_modules: List[Callable] = SUPPORTED_QUANT_MODULES,
+        supported_quant_ops: Optional[List[Callable]] = SUPPORTED_QUANT_OPS,
+        quant: Optional[bool] = None,
+    ):
+        super().__init__()
+        self.supported_modules = set(supported_modules)
+        self.supported_ops = set(supported_ops or [])
+        self.supported_quant_modules = set(supported_quant_modules)
+        supported_quant_ops = supported_quant_ops or []
+        self.supported_quant_ops = set(supported_quant_ops + self._QUANT_OPS)
+
+        self.quant = quant
+
+        self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, [])
+        self.partition_tags: Dict[str, DelegationSpec] = {}
+
+    def get_supported_modules(self, quant: bool) -> Set[Callable]:
+        """
+        Get supported modules
+        """
+        if quant is True:
+            return self.supported_quant_modules
+        elif quant is False:
+            return self.supported_modules
+        else:
+            return self.supported_modules | self.supported_quant_modules
+
+    def get_supported_ops(self, quant: Optional[bool]) -> Set[Callable]:
+        """
+        Get supported ops
+        """
+        if quant is True:
+            return self.supported_quant_ops
+        elif quant is False:
+            return self.supported_ops
+        else:
+            return self.supported_ops | self.supported_quant_ops
+
+    @staticmethod
+    def check_partitions(partitions: Union[dict, list]) -> bool:
+        """
+        Warn users if there aren't any matches
+
+        TODO: convert this into a stronger validation, may need a flag in
+        `to_backend()` or partitioner __init__()
+        """
+        pl = len(partitions)
+        if pl == 0:
+            log.warning("Nothing can be partitioned!")
+        else:
+            log.info(f"Found {pl} subgraphs to be partitioned.")
+        return pl != 0
+
+    def get_input_deps(  # noqa
+        self, input_nodes: List[torch.fx.Node]
+    ) -> List[torch.fx.Node]:
+        """
+        For each input node, walk up and pull the necessary quant/attr nodes into the partition
+        """
+        nodes = set()
+        for inp in input_nodes:
+            if inp.target in self._DQ_OPS:
+                # dequant node
+                nodes.add(inp)
+
+                # possible per_channel scale/zp for the dequant node args{1, 2}
+                for i in [1, 2]:
+                    node = inp.args[i]
+                    if isinstance(node, torch.fx.Node) and node.op == "get_attr":
+                        nodes.add(node)
+
+                # quant node
+                q_prod = inp.args[0]
+                assert (
+                    isinstance(q_prod, torch.fx.Node) and q_prod.target in self._Q_OPS
+                )
+                nodes.add(q_prod)
+
+                # possible weight for the quant node arg{0}
+                node = q_prod.args[0]
+                if isinstance(node, torch.fx.Node) and node.op == "get_attr":
+                    nodes.add(node)
+
+                # possible nodes for quant node args{1, 2}
+                for i in [1, 2]:
+                    node = q_prod.args[i]
+                    # possible choose_qparam
+                    if (
+                        isinstance(node, torch.fx.Node)
+                        and node.op == "call_function"
+                        and node.target == operator.getitem
+                    ):
+                        parent = node.args[0]
+                        if (
+                            isinstance(parent, torch.fx.Node)
+                            and parent.op == "call_function"
+                            and parent.target in self._QPARAM_OPS
+                        ):
+                            nodes.add(node)
+                            nodes.add(parent)
+
+                    # possible per_channel scale/zp for the quant node
+                    elif isinstance(node, torch.fx.Node) and node.op == "get_attr":
+                        nodes.add(node)
+        return list(nodes)
+
+    def get_output_deps(self, output_nodes: List[torch.fx.Node]) -> List[torch.fx.Node]:
+        """
+        For each output node, check all the users and insert them into the partition if needed
+        """
+        nodes = []
+        for output in output_nodes:
+            for node in output.users:
+                if node.target in self._Q_OPS:
+                    nodes.append(node)
+                    users = list(node.users.keys())
+                    for dq_user in users:
+                        assert (
+                            dq_user.target in self._DQ_OPS
+                        ), f"Expecting a dq node(s) after a q node, but got target {dq_user.target} for {dq_user} node"
+                        nodes.append(dq_user)
+        return nodes
+
+    def get_nodes(
+        self, src_partition: SourcePartition, quant: bool
+    ) -> List[torch.fx.Node]:
+        """
+        Return nodes from the source partition.
+        """
+        if quant:
+            # Insert quantization ops into src_partition by following the input and output nodes.
+            return (
+                src_partition.nodes
+                + self.get_input_deps(src_partition.input_nodes)
+                + self.get_output_deps(src_partition.output_nodes)
+            )
+        else:
+            return src_partition.nodes
+
+    def qualify_nodes(self, input_nodes: List[torch.fx.Node]) -> bool:
+        """
+        Each node in the module (post decomposition) must satisfy the
+        constraints specified for XNNPACK.
+
+        Disqualify the whole module if one of the nodes fails to satisfy.
+        """
+        return all(
+            XnnpackOperatorSupport.check_constraint(node) for node in input_nodes
+        )
+
+    def get_module_partitions(
+        self, graph_module: torch.fx.GraphModule, quant: Optional[bool]
+    ) -> List[List[torch.fx.Node]]:
+        """
+        Get all partitions in the torch.fx.GraphModule for the supported
+        modules.
+        """
+
+        if quant is None:
+            module_partitions = self.get_module_partitions(graph_module, True)
+            for node_list in module_partitions:
+                for node in node_list:
+                    node.meta["quant_match"] = True
+            fp32_module_partitions = self.get_module_partitions(graph_module, False)
+            for node_list in fp32_module_partitions:
+                for node in node_list:
+                    if node.meta.get("quant_match", False):
+                        break
+                else:
+                    module_partitions.append(node_list)
+            for node_list in module_partitions:
+                for node in node_list:
+                    node.meta.pop("quant_match", False)
+            return module_partitions
+
+        src_partition_dict = get_source_partitions(
+            graph_module.graph, self.get_supported_modules(quant)
+        )
+        all_partitions = src_partition_dict.values()
+
+        module_partitions = []
+        for src_partitions in all_partitions:
+            for src_partition in src_partitions:
+                partition_nodes = self.get_nodes(src_partition, quant)
+                if self.qualify_nodes(partition_nodes):
+                    module_partitions.append(partition_nodes)
+
+        return module_partitions
+
+    def generate_partitions(
+        self, graph_module: torch.fx.GraphModule, quant: Optional[bool]
+    ) -> List[Any]:
+        """
+        Generate a list of partitions for a torch.fx.GraphModule.
+        Also pass the supported ops to match.
+        """
+        matched_module_nodes = self.get_module_partitions(graph_module, quant)
+        return generate_partitions_from_list_of_nodes(
+            graph_module,
+            matched_module_nodes,
+            XnnpackOperatorSupport(supported_ops=list(self.get_supported_ops(quant))),
+        )
+
+    def tag_nodes(self, partitions: List[Partition]) -> None:
+        """
+        Tag each partition in the list with its delegation tag.
+        """
+        for partition in partitions:
+            # Add delegation tags
+            skip = False
+            for node in partition.nodes:
+                if "delegation_tag" in node.meta:
+                    skip = True
+            if skip:
+                continue
+            for node in partition.nodes:
+                delegation_tag = f"tag{partition.id}"
+                node.meta["delegation_tag"] = delegation_tag
+                self.partition_tags[delegation_tag] = self.delegation_spec
+
+    # override
+    def _partition(
+        self, graph_module: torch.fx.GraphModule, quant: Optional[bool]
+    ) -> torch.fx.GraphModule:
+        """
+        Run the partitioner on the given graph module, then tag each partition
+        with its delegation tag (and partition id)
+        """
+        partitions = self.generate_partitions(graph_module, quant)
+        if self.check_partitions(partitions):
+            self.tag_nodes(partitions)
+        return graph_module
+
+    def partition(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        ret = self._partition(graph_module, self.quant)
+        return ret
+
+
 class XnnpackDynamicallyQuantizedPartitioner(XnnpackQuantizedPartitioner):
     def __init__(
         self,
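The subtlest part of the new class is the quant is None branch of get_module_partitions: quantized matches are collected first and their nodes stamped with a temporary quant_match marker, and an fp32 match is kept only if it shares no node with any quantized match (a for/else loop, where the else runs only when the inner loop never hits break). A self-contained sketch of that de-duplication pattern, using plain strings as stand-in node names rather than real torch.fx nodes:

# Quantized matches win; an fp32 match survives only if it overlaps none of them.
quant_partitions = [["q_conv", "dq_conv"], ["q_linear"]]
fp32_partitions = [["q_conv", "dq_conv"], ["sub"]]

matched = {n for p in quant_partitions for n in p}  # plays the role of quant_match
merged = list(quant_partitions)
for node_list in fp32_partitions:
    for node in node_list:
        if node in matched:  # overlap with a quantized match ...
            break            # ... discards the whole fp32 partition
    else:  # no break: the partition is fp32-only, keep it
        merged.append(node_list)

assert merged == [["q_conv", "dq_conv"], ["q_linear"], ["sub"]]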

backends/xnnpack/test/test_xnnpack_utils.py

Lines changed: 3 additions & 4 deletions
@@ -14,8 +14,7 @@

 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackDynamicallyQuantizedPartitioner,
-    XnnpackFloatingPointPartitioner,
-    XnnpackQuantizedPartitioner,
+    XnnpackPartitioner,
 )
 from executorch.backends.xnnpack.utils.configs import (
     get_transform_passes,
@@ -185,9 +184,9 @@ def forward(self, *args):
             if quantized_dynamic:
                 partitioner = XnnpackDynamicallyQuantizedPartitioner
             else:
-                partitioner = XnnpackQuantizedPartitioner
+                partitioner = XnnpackPartitioner
         else:
-            partitioner = XnnpackFloatingPointPartitioner
+            partitioner = XnnpackPartitioner

         if use_partitioner:
             with validation_disabled():
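With both call sites collapsed onto XnnpackPartitioner, delegation is expressed entirely through node metadata: partition() stamps each delegable node with a delegation_tag, and partition_tags maps every tag to the XnnpackBackend DelegationSpec recorded in __init__. A hypothetical inspection sketch, assuming gm is an edge-dialect torch.fx.GraphModule that has just been produced elsewhere:

p = XnnpackPartitioner()  # quant=None: consider fp32 and quantized patterns
gm = p.partition(gm)      # tags nodes in place and returns the same module

for node in gm.graph.nodes:
    tag = node.meta.get("delegation_tag")
    if tag is not None:
        # each tag resolves to the XnnpackBackend DelegationSpec, e.g. "tag0"
        print(f"{node.name}: {tag} -> {p.partition_tags[tag]}")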
