From f637f9b9a109a2d495ec13d400bde12491e94158 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 24 Jul 2025 18:27:04 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- .../group_partitioner.py | 63 ++++++++++++++----- 1 file changed, 48 insertions(+), 15 deletions(-) diff --git a/exir/backend/canonical_partitioners/group_partitioner.py b/exir/backend/canonical_partitioners/group_partitioner.py index 9aaeacd9d0a..ce08edc8a66 100644 --- a/exir/backend/canonical_partitioners/group_partitioner.py +++ b/exir/backend/canonical_partitioners/group_partitioner.py @@ -101,19 +101,25 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id): p2_nodes = set(partitions_by_id[p2].nodes.keys()) combined_nodes = p1_nodes.union(p2_nodes) - for node in combined_nodes: - # Get all downstream nodes that are not in the combined partition - external_downstreams = { - n - for n in self.dependency_viewer.downstreams_of(node) - if n not in combined_nodes - } + user_nodes = [] + # topologically, p2_nodes comes before p1_nodes, so we only + # need to check the downstream nodes of p2. + # Additionally, we don't need to check all the downstream nodes + # of p2, we only need to check the nodes directly outside of p2. + # example: + # partition[a --> b --> c] --> d --> e --> f + # we don't need to check [d, e, f] we only need to check [d] because + # the downstream users of [d] will include [e, f] + for node in p2_nodes: + for user in node.users: + if user not in combined_nodes: + user_nodes.append(user) + for external_node in user_nodes: # Check if any external downstream nodes have downstream nodes in the combined partition - for external_node in external_downstreams: - downstream_nodes = self.dependency_viewer.downstreams_of(external_node) - if any(n in combined_nodes for n in downstream_nodes): - return False + downstream_nodes = self.dependency_viewer.downstreams_of(external_node) + if any(n in combined_nodes for n in downstream_nodes): + return False return True @@ -133,13 +139,35 @@ def _process_node_groups( if not self.node_groups: return group_to_partition_id - for i, group in enumerate(self.node_groups): - # Create a partition for each group + node_to_group_index = {} + for idx, group in enumerate(self.node_groups): + for node in group: + node_to_group_index[node] = idx + + processed_nodes = set() + + # We have to create the partitions in reverse topological order + # so we find the groups as we traverse backwards in the graph + # this likely needs to be combined with the process_remaining_nodes + # TODO: this currently doesn't work with _process_remaining_nodes so + # if a user provides grouped nodes with operatorsupport, then this will + # faile + for node in reversed(self.graph_module.graph.nodes): + if node not in node_to_group_index: + continue + + if node in processed_nodes: + continue + + group_idx = node_to_group_index[node] + group = self.node_groups[group_idx] + + # Create a partition for group partition_id = next(new_partition_id) partition = Partition(id=partition_id, nodes=set()) partitions_by_id[partition_id] = partition partitions_order[partition_id] = partition_id - group_to_partition_id[i] = partition_id + group_to_partition_id[group_idx] = partition_id # Add all supported nodes from the group to the partition for node in group: @@ -164,6 +192,12 @@ def _process_node_groups( partition_map[partition_id].add(target_id) partition_map[partition_id].update(partition_map[target_id]) + # all the nodes in the group have now been processed + # so skip if we encoutner them again in our rev topo + # iteration + for node in group: + processed_nodes.add(node) + return group_to_partition_id def _process_remaining_nodes( @@ -209,7 +243,6 @@ def _merge_partitions( # Set to track removed partitions from initial static list so we can skip them already_merged = set() - # Try to merge each pair of partitions for i, p1 in enumerate(partition_ids): # Skip if this partition has been already merged From ec7a6b5e44ea7184543618c7e7877b3ccd8d1efb Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 24 Jul 2025 23:42:46 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- exir/backend/canonical_partitioners/group_partitioner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/exir/backend/canonical_partitioners/group_partitioner.py b/exir/backend/canonical_partitioners/group_partitioner.py index 74f782f597a..63bedad3b42 100644 --- a/exir/backend/canonical_partitioners/group_partitioner.py +++ b/exir/backend/canonical_partitioners/group_partitioner.py @@ -139,7 +139,6 @@ def _process_node_groups( if not self.node_groups: return group_to_partition_id - processed_nodes = set() # We have to create the partitions in reverse topological order