From f637f9b9a109a2d495ec13d400bde12491e94158 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 24 Jul 2025 18:27:04 -0700 Subject: [PATCH 1/6] Update [ghstack-poisoned] --- .../group_partitioner.py | 63 ++++++++++++++----- 1 file changed, 48 insertions(+), 15 deletions(-) diff --git a/exir/backend/canonical_partitioners/group_partitioner.py b/exir/backend/canonical_partitioners/group_partitioner.py index 9aaeacd9d0a..ce08edc8a66 100644 --- a/exir/backend/canonical_partitioners/group_partitioner.py +++ b/exir/backend/canonical_partitioners/group_partitioner.py @@ -101,19 +101,25 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id): p2_nodes = set(partitions_by_id[p2].nodes.keys()) combined_nodes = p1_nodes.union(p2_nodes) - for node in combined_nodes: - # Get all downstream nodes that are not in the combined partition - external_downstreams = { - n - for n in self.dependency_viewer.downstreams_of(node) - if n not in combined_nodes - } + user_nodes = [] + # topologically, p2_nodes comes before p1_nodes, so we only + # need to check the downstream nodes of p2. + # Additionally, we don't need to check all the downstream nodes + # of p2, we only need to check the nodes directly outside of p2. + # example: + # partition[a --> b --> c] --> d --> e --> f + # we don't need to check [d, e, f] we only need to check [d] because + # the downstream users of [d] will include [e, f] + for node in p2_nodes: + for user in node.users: + if user not in combined_nodes: + user_nodes.append(user) + for external_node in user_nodes: # Check if any external downstream nodes have downstream nodes in the combined partition - for external_node in external_downstreams: - downstream_nodes = self.dependency_viewer.downstreams_of(external_node) - if any(n in combined_nodes for n in downstream_nodes): - return False + downstream_nodes = self.dependency_viewer.downstreams_of(external_node) + if any(n in combined_nodes for n in downstream_nodes): + return False return True @@ -133,13 +139,35 @@ def _process_node_groups( if not self.node_groups: return group_to_partition_id - for i, group in enumerate(self.node_groups): - # Create a partition for each group + node_to_group_index = {} + for idx, group in enumerate(self.node_groups): + for node in group: + node_to_group_index[node] = idx + + processed_nodes = set() + + # We have to create the partitions in reverse topological order + # so we find the groups as we traverse backwards in the graph + # this likely needs to be combined with the process_remaining_nodes + # TODO: this currently doesn't work with _process_remaining_nodes so + # if a user provides grouped nodes with operatorsupport, then this will + # faile + for node in reversed(self.graph_module.graph.nodes): + if node not in node_to_group_index: + continue + + if node in processed_nodes: + continue + + group_idx = node_to_group_index[node] + group = self.node_groups[group_idx] + + # Create a partition for group partition_id = next(new_partition_id) partition = Partition(id=partition_id, nodes=set()) partitions_by_id[partition_id] = partition partitions_order[partition_id] = partition_id - group_to_partition_id[i] = partition_id + group_to_partition_id[group_idx] = partition_id # Add all supported nodes from the group to the partition for node in group: @@ -164,6 +192,12 @@ def _process_node_groups( partition_map[partition_id].add(target_id) partition_map[partition_id].update(partition_map[target_id]) + # all the nodes in the group have now been processed + # so skip if we encoutner them again in our rev topo + # iteration + for node in group: + processed_nodes.add(node) + return group_to_partition_id def _process_remaining_nodes( @@ -209,7 +243,6 @@ def _merge_partitions( # Set to track removed partitions from initial static list so we can skip them already_merged = set() - # Try to merge each pair of partitions for i, p1 in enumerate(partition_ids): # Skip if this partition has been already merged From 2b48db92557ea2aff1ad5cfdd8ecc5dc6e1d709f Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 24 Jul 2025 18:27:12 -0700 Subject: [PATCH 2/6] Update [ghstack-poisoned] --- .../config_partitioner.py | 49 +++++++++++++++---- .../pattern_op_partitioner.py | 48 ++++++++++++++++++ 2 files changed, 87 insertions(+), 10 deletions(-) diff --git a/exir/backend/canonical_partitioners/config_partitioner.py b/exir/backend/canonical_partitioners/config_partitioner.py index 1a9bcc33e80..f9ef44cbf22 100644 --- a/exir/backend/canonical_partitioners/config_partitioner.py +++ b/exir/backend/canonical_partitioners/config_partitioner.py @@ -10,13 +10,17 @@ import torch from executorch.exir.backend.backend_details import ExportedProgram from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( - generate_partitions_from_list_of_nodes, + generate_grouped_partitions_from_list_of_nodes, ) from executorch.exir.backend.partitioner import ( DelegationSpec, Partitioner, PartitionResult, ) + +from exir.backend.canonical_partitioners.pattern_op_partitioner import ( + generate_grouped_partitions_from_list_of_nodes, +) from torch.fx.passes.infra.partitioner import Partition @@ -162,23 +166,48 @@ def filter_fn(node: torch.fx.Node) -> bool: def get_matched_nodes_from_configs( self, ep: ExportedProgram ) -> List[List[torch.fx.Node]]: + # disjoint set union + parent = {} + + def find(x): + parent.setdefault(x, x) + if parent[x] != x: + parent[x] = find(parent[x]) + return parent[x] + + def union(x, y): + parent[find(x)] = find(y) + # gather supported nodes - matched_nodes = [] gm = ep.graph_module for node in gm.graph.nodes: - if node.op == "call_function": - target = format_target_name(node.target.__name__) - if target in self.target_partitioner_configs: - node_config = self.target_partitioner_configs[target] - if node_config.check_constraints(node, ep): - matched_nodes.append(node_config.get_partition(node, ep)) + if node.op != "call_function": + continue + target = format_target_name(node.target.__name__) + + if target not in self.target_partitioner_configs: + continue + + node_config = self.target_partitioner_configs[target] + if not node_config.check_constraints(node, ep): + continue + + partition = node_config.get_partition(node, ep) + parent[partition[0]] = partition[0] + for i in range(1, len(partition)): + union(partition[0], partition[i]) + + groups = {} + for node in parent.keys(): + root = find(node) + groups.setdefault(root, set()).add(node) - return matched_nodes + return [list(group) for group in groups.values()] def generate_partitions(self, ep: ExportedProgram) -> List[Partition]: matched_nodes = self.get_matched_nodes_from_configs(ep) # create partitions - partitions = generate_partitions_from_list_of_nodes( + partitions = generate_grouped_partitions_from_list_of_nodes( ep.graph_module, matched_nodes, ) diff --git a/exir/backend/canonical_partitioners/pattern_op_partitioner.py b/exir/backend/canonical_partitioners/pattern_op_partitioner.py index 7a3c943d258..3d11a80b2ee 100644 --- a/exir/backend/canonical_partitioners/pattern_op_partitioner.py +++ b/exir/backend/canonical_partitioners/pattern_op_partitioner.py @@ -8,6 +8,10 @@ from typing import List, Optional import torch + +from executorch.exir.backend.canonical_partitioners.group_partitioner import ( + GroupBasedPartitioner, +) from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.passes.operator_support import any_chain, OperatorSupportBase from torch.fx.passes.utils.matcher_utils import SubgraphMatcher @@ -56,6 +60,50 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: return partition_list +def generate_grouped_partitions_from_list_of_nodes( + graph_module: torch.fx.GraphModule, + pattern_list: Optional[List[List[torch.fx.Node]]] = None, + op_support: Optional[OperatorSupportBase] = None, +) -> List[Partition]: + final_op_support: Optional[OperatorSupportBase] = op_support + + if pattern_list is not None: + # Tag all the nodes in these patterns + for node_list in pattern_list: + for node in node_list: + node.meta["match"] = True + + class MatchTag(OperatorSupportBase): + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + return node.meta.get("match", False) + + final_op_support = ( + MatchTag() + if final_op_support is None + else any_chain(final_op_support, MatchTag()) + ) + + assert ( + final_op_support is not None + ), "Did not give a pattern or OperatorSupportBase instance to partition with" + + # Run the CapabilityBasedPartitioner to return the largest possible + # subgraphs containing the nodes with the tags + group_partitioner = GroupBasedPartitioner( + graph_module, + final_op_support, + node_groups=pattern_list, + allows_single_node_partition=True, + ) + partition_list = group_partitioner.propose_partitions() + + # Remove the metadata field we added + for partition in partition_list: + for node in partition.nodes: + node.meta.pop("match", False) + return partition_list + + def generate_pattern_op_partitions( graph_module: torch.fx.GraphModule, patterns: Optional[List[torch.fx.Graph]] = None, From 022022aa06db250cb9abbfe5d14cc0b6e51076d5 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 24 Jul 2025 18:49:35 -0700 Subject: [PATCH 3/6] Update [ghstack-poisoned] --- exir/backend/canonical_partitioners/config_partitioner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/exir/backend/canonical_partitioners/config_partitioner.py b/exir/backend/canonical_partitioners/config_partitioner.py index f9ef44cbf22..de7224de081 100644 --- a/exir/backend/canonical_partitioners/config_partitioner.py +++ b/exir/backend/canonical_partitioners/config_partitioner.py @@ -193,9 +193,10 @@ def union(x, y): continue partition = node_config.get_partition(node, ep) - parent[partition[0]] = partition[0] - for i in range(1, len(partition)): - union(partition[0], partition[i]) + if len(partition) > 0: + parent[partition[0]] = partition[0] + for i in range(1, len(partition)): + union(partition[0], partition[i]) groups = {} for node in parent.keys(): From ec7a6b5e44ea7184543618c7e7877b3ccd8d1efb Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 24 Jul 2025 23:42:46 -0700 Subject: [PATCH 4/6] Update [ghstack-poisoned] --- exir/backend/canonical_partitioners/group_partitioner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/exir/backend/canonical_partitioners/group_partitioner.py b/exir/backend/canonical_partitioners/group_partitioner.py index 74f782f597a..63bedad3b42 100644 --- a/exir/backend/canonical_partitioners/group_partitioner.py +++ b/exir/backend/canonical_partitioners/group_partitioner.py @@ -139,7 +139,6 @@ def _process_node_groups( if not self.node_groups: return group_to_partition_id - processed_nodes = set() # We have to create the partitions in reverse topological order From 2034cdde32800d5972acfd1fb7fd6c386c117721 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Fri, 25 Jul 2025 10:58:19 -0700 Subject: [PATCH 5/6] Update [ghstack-poisoned] --- .../config_partitioner.py | 73 ++++++++++++++----- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/exir/backend/canonical_partitioners/config_partitioner.py b/exir/backend/canonical_partitioners/config_partitioner.py index 440f7b56d39..16933376807 100644 --- a/exir/backend/canonical_partitioners/config_partitioner.py +++ b/exir/backend/canonical_partitioners/config_partitioner.py @@ -17,9 +17,18 @@ Partitioner, PartitionResult, ) + +from sympy.logic.boolalg import disjuncts +from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param from torch.fx.passes.infra.partitioner import Partition +def is_constant_data(ep: ExportedProgram, node: torch.fx.Node) -> bool: + return ( + is_param(ep, node) or is_buffer(ep, node) or is_lifted_tensor_constant(ep, node) + ) + + def format_target_name(target_name: str) -> str: """ We remove the dialect name space from the target name. We generally @@ -100,6 +109,35 @@ def get_partition( pass +class DSJ: + """ + Disjoint set union data structure used to find connected components in the graph. + """ + + def __init__(self): + self.parent = {} + + def find(self, x): + self.parent.setdefault(x, x) + if self.parent[x] != x: + self.parent[x] = self.find(self.parent[x]) + return self.parent[x] + + def union(self, x, y): + self.parent[self.find(x)] = self.find(y) + + def contains(self, x): + return x in self.parent + + def gen_groups(self): + groups = {} + for node in self.parent.keys(): + root = self.find(node) + groups.setdefault(root, set()).add(node) + + return [list(group) for group in groups.values()] + + class ConfigerationBasedPartitioner(Partitioner): def __init__( self, @@ -162,17 +200,8 @@ def filter_fn(node: torch.fx.Node) -> bool: def get_matched_nodes_from_configs( self, ep: ExportedProgram ) -> List[List[torch.fx.Node]]: - # disjoint set union - parent = {} - - def find(x): - parent.setdefault(x, x) - if parent[x] != x: - parent[x] = find(parent[x]) - return parent[x] - - def union(x, y): - parent[find(x)] = find(y) + # disjoint set union for merging partitions + dsj = DSJ() # gather supported nodes gm = ep.graph_module @@ -188,18 +217,22 @@ def union(x, y): if not node_config.check_constraints(node, ep): continue - partition = node_config.get_partition(node, ep) + partition_candidate = node_config.get_partition(node, ep) + partition = [] + for node in partition_candidate: + # partitioner infra copies constant data across partitions, so it + # is ok if this partition doesn't have it + if is_constant_data(ep, node) and dsj.contains(node): + continue + partition.append(node) + + # Union overlaps into a single group if len(partition) > 0: - parent[partition[0]] = partition[0] + dsj.find(partition[0]) for i in range(1, len(partition)): - union(partition[0], partition[i]) + dsj.union(partition[0], partition[i]) - groups = {} - for node in parent.keys(): - root = find(node) - groups.setdefault(root, set()).add(node) - - return [list(group) for group in groups.values()] + return dsj.gen_groups() def generate_partitions(self, ep: ExportedProgram) -> List[Partition]: matched_nodes = self.get_matched_nodes_from_configs(ep) From 7ff7d32cd19fc0ede389a219e32a831529aa044e Mon Sep 17 00:00:00 2001 From: Max Ren Date: Fri, 25 Jul 2025 12:03:40 -0700 Subject: [PATCH 6/6] Update [ghstack-poisoned] --- exir/backend/canonical_partitioners/config_partitioner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/exir/backend/canonical_partitioners/config_partitioner.py b/exir/backend/canonical_partitioners/config_partitioner.py index 16933376807..09835cd2b59 100644 --- a/exir/backend/canonical_partitioners/config_partitioner.py +++ b/exir/backend/canonical_partitioners/config_partitioner.py @@ -18,7 +18,6 @@ PartitionResult, ) -from sympy.logic.boolalg import disjuncts from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param from torch.fx.passes.infra.partitioner import Partition