diff --git a/exir/backend/canonical_partitioners/group_partitioner.py b/exir/backend/canonical_partitioners/group_partitioner.py index 9aaeacd9d0a..63bedad3b42 100644 --- a/exir/backend/canonical_partitioners/group_partitioner.py +++ b/exir/backend/canonical_partitioners/group_partitioner.py @@ -86,7 +86,7 @@ def __init__( ) self.node_to_group = collections.defaultdict(int) self.all_nodes_in_groups = set() - if node_groups: + if self.node_groups: for i, group in enumerate(self.node_groups): for node in group: # Node is in multiple groups - not allowed @@ -101,19 +101,25 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id): p2_nodes = set(partitions_by_id[p2].nodes.keys()) combined_nodes = p1_nodes.union(p2_nodes) - for node in combined_nodes: - # Get all downstream nodes that are not in the combined partition - external_downstreams = { - n - for n in self.dependency_viewer.downstreams_of(node) - if n not in combined_nodes - } + user_nodes = [] + # topologically, p2_nodes comes before p1_nodes, so we only + # need to check the downstream nodes of p2. + # Additionally, we don't need to check all the downstream nodes + # of p2, we only need to check the nodes directly outside of p2. + # example: + # partition[a --> b --> c] --> d --> e --> f + # we don't need to check [d, e, f] we only need to check [d] because + # the downstream users of [d] will include [e, f] + for node in p2_nodes: + for user in node.users: + if user not in combined_nodes: + user_nodes.append(user) + for external_node in user_nodes: # Check if any external downstream nodes have downstream nodes in the combined partition - for external_node in external_downstreams: - downstream_nodes = self.dependency_viewer.downstreams_of(external_node) - if any(n in combined_nodes for n in downstream_nodes): - return False + downstream_nodes = self.dependency_viewer.downstreams_of(external_node) + if any(n in combined_nodes for n in downstream_nodes): + return False return True @@ -133,13 +139,30 @@ def _process_node_groups( if not self.node_groups: return group_to_partition_id - for i, group in enumerate(self.node_groups): - # Create a partition for each group + processed_nodes = set() + + # We have to create the partitions in reverse topological order + # so we find the groups as we traverse backwards in the graph + # this likely needs to be combined with the process_remaining_nodes + # TODO: this currently doesn't work with _process_remaining_nodes so + # if a user provides grouped nodes with operatorsupport, then this will + # faile + for node in reversed(self.graph_module.graph.nodes): + if node not in self.node_to_group: + continue + + if node in processed_nodes: + continue + + group_idx = self.node_to_group[node] + group = self.node_groups[group_idx] + + # Create a partition for group partition_id = next(new_partition_id) partition = Partition(id=partition_id, nodes=set()) partitions_by_id[partition_id] = partition partitions_order[partition_id] = partition_id - group_to_partition_id[i] = partition_id + group_to_partition_id[group_idx] = partition_id # Add all supported nodes from the group to the partition for node in group: @@ -164,6 +187,12 @@ def _process_node_groups( partition_map[partition_id].add(target_id) partition_map[partition_id].update(partition_map[target_id]) + # all the nodes in the group have now been processed + # so skip if we encoutner them again in our rev topo + # iteration + for node in group: + processed_nodes.add(node) + return group_to_partition_id def _process_remaining_nodes( @@ -209,7 +238,6 @@ def _merge_partitions( # Set to track removed partitions from initial static list so we can skip them already_merged = set() - # Try to merge each pair of partitions for i, p1 in enumerate(partition_ids): # Skip if this partition has been already merged