-
Notifications
You must be signed in to change notification settings - Fork 743
Pass Dependencies When Proposing Partitions VIA New Group-Based Partitioner #12072
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
389 changes: 389 additions & 0 deletions
389
exir/backend/canonical_partitioners/group_partitioner.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,389 @@ | ||
| # mypy: allow-untyped-defs | ||
| import collections | ||
| import itertools | ||
| import logging | ||
| from collections.abc import Sequence | ||
| from typing import List, Optional | ||
|
|
||
| from torch.fx.graph_module import GraphModule | ||
| from torch.fx.node import _get_qualified_name, Node | ||
| from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition | ||
| from torch.fx.passes.operator_support import OperatorSupportBase | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
| logger.setLevel(logging.WARNING) | ||
|
|
||
|
|
||
| class _DependencyViewer: | ||
| def __init__(self, graph_module: GraphModule): | ||
| self.downstreams = collections.defaultdict(set) | ||
| self.upstreams = collections.defaultdict(set) | ||
|
|
||
| for node in reversed(graph_module.graph.nodes): | ||
| for output_node in node.users: | ||
| # add output_node and output_node's downstream dependency | ||
| self.downstreams[node].add(output_node) | ||
| self.downstreams[node].update(self.downstreams[output_node]) | ||
|
|
||
| for node in graph_module.graph.nodes: | ||
| for input_node in node.all_input_nodes: | ||
| self.upstreams[node].add(input_node) | ||
| self.upstreams[node].update(self.upstreams[input_node]) | ||
|
|
||
| def downstreams_of(self, node: Node) -> set[Node]: | ||
| return self.downstreams[node] | ||
|
|
||
| def upstreams_of(self, node: Node) -> set[Node]: | ||
| return self.upstreams[node] | ||
|
|
||
|
|
||
| class GroupBasedPartitioner(CapabilityBasedPartitioner): | ||
| """ | ||
| A specialized partitioner that extends the CapabilityBasedPartitioner from PyTorch FX. | ||
|
|
||
| GroupBasedPartitioner allows for explicit grouping of nodes into partitions based on | ||
| predefined node groups, while also supporting automatic partitioning for nodes not | ||
| included in any group. Nodes are only allowed to be in one group. | ||
|
|
||
| Features: | ||
| - Explicit Node Grouping: Allows specifying groups of nodes that should be kept together | ||
| in the same partition. | ||
| - Automatic Partitioning: Nodes not included in any explicit group are automatically | ||
| partitioned based on operator support. | ||
| - Cycle Prevention: Ensures that partitioning doesn't create cycles in the execution graph. | ||
| - Single Node Partition Control: Options to allow or disallow single-node partitions, | ||
| with exceptions for specific operations. | ||
|
|
||
| Args: | ||
| graph_module: The FX GraphModule to be partitioned. | ||
| operator_support: Interface to determine if a node is supported by the target backend. | ||
| allows_single_node_partition: Whether to allow single-node partitions. Default: False. | ||
| non_compute_ops: Operations not counted for single-node partition determination. Default: None. | ||
| allowed_single_node_partition_ops: Operations allowed as single-node partitions. Default: None. | ||
| node_groups: Lists of nodes to group together in partitions. Default: None. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| graph_module: GraphModule, | ||
| operator_support: OperatorSupportBase, | ||
| allows_single_node_partition: bool = False, | ||
| non_compute_ops: Optional[Sequence[str]] = None, | ||
| allowed_single_node_partition_ops: Optional[Sequence[str]] = None, | ||
| node_groups: List[List[Node]] = None, | ||
| ) -> None: | ||
| super().__init__( | ||
| graph_module=graph_module, | ||
| operator_support=operator_support, | ||
| allows_single_node_partition=allows_single_node_partition, | ||
| non_compute_ops=non_compute_ops, | ||
| allowed_single_node_partition_ops=allowed_single_node_partition_ops, | ||
| ) | ||
| self.dependency_viewer = _DependencyViewer(graph_module) | ||
| self.node_groups = ( | ||
| [set(node_group) for node_group in node_groups] if node_groups else None | ||
| ) | ||
| self.node_to_group = collections.defaultdict(int) | ||
| self.all_nodes_in_groups = set() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what if a node is in two groups? Should we check against that? |
||
| if node_groups: | ||
| for i, group in enumerate(self.node_groups): | ||
| for node in group: | ||
| # Node is in multiple groups - not allowed | ||
| if node in self.node_to_group: | ||
| raise ValueError(f"Node {node} exists in multiple groups.") | ||
| self.node_to_group[node] = i | ||
| self.all_nodes_in_groups.add(node) | ||
|
|
||
| def _can_merge_partitions(self, p1, p2, partitions_by_id): | ||
| """Check if merging two partitions would create a cycle.""" | ||
| p1_nodes = set(partitions_by_id[p1].nodes.keys()) | ||
| p2_nodes = set(partitions_by_id[p2].nodes.keys()) | ||
| combined_nodes = p1_nodes.union(p2_nodes) | ||
|
|
||
| for node in combined_nodes: | ||
| # Get all downstream nodes that are not in the combined partition | ||
| external_downstreams = { | ||
| n | ||
| for n in self.dependency_viewer.downstreams_of(node) | ||
| if n not in combined_nodes | ||
| } | ||
|
|
||
| # Check if any external downstream nodes have downstream nodes in the combined partition | ||
| for external_node in external_downstreams: | ||
| downstream_nodes = self.dependency_viewer.downstreams_of(external_node) | ||
| if any(n in combined_nodes for n in downstream_nodes): | ||
| return False | ||
|
|
||
| return True | ||
|
|
||
| def _process_node_groups( | ||
| self, | ||
| new_partition_id, | ||
| partitions_by_id, | ||
| assignment, | ||
| nodes_order, | ||
| partitions_order, | ||
| partition_users, | ||
| partition_map, | ||
| ): | ||
| """Process nodes in predefined groups.""" | ||
| group_to_partition_id = {} | ||
|
|
||
| if not self.node_groups: | ||
| return group_to_partition_id | ||
|
|
||
| for i, group in enumerate(self.node_groups): | ||
| # Create a partition for each group | ||
| partition_id = next(new_partition_id) | ||
| partition = Partition(id=partition_id, nodes=set()) | ||
| partitions_by_id[partition_id] = partition | ||
| partitions_order[partition_id] = partition_id | ||
| group_to_partition_id[i] = partition_id | ||
|
|
||
| # Add all supported nodes from the group to the partition | ||
| for node in group: | ||
| if self._is_node_supported(node): | ||
| partition.add_node(node) | ||
| assignment[node] = partition_id | ||
| nodes_order[node] = partition_id | ||
|
|
||
| # Set partition users | ||
| partition_users[partition_id] = { | ||
| user | ||
| for node in partition.nodes | ||
| for user in node.users | ||
| if user not in partition.nodes | ||
| } | ||
|
|
||
| # Update partition map | ||
| for node in partition.nodes: | ||
| for user in node.users: | ||
| target_id = assignment.get(user) | ||
| if target_id is not None and target_id != partition_id: | ||
| partition_map[partition_id].add(target_id) | ||
| partition_map[partition_id].update(partition_map[target_id]) | ||
|
|
||
| return group_to_partition_id | ||
|
|
||
| def _process_remaining_nodes( | ||
| self, | ||
| new_partition_id, | ||
| partitions_by_id, | ||
| assignment, | ||
| nodes_order, | ||
| partitions_order, | ||
| partition_users, | ||
| partition_map, | ||
| ): | ||
| """Process nodes not in any predefined group.""" | ||
| for node in reversed(self.graph_module.graph.nodes): | ||
| if node in assignment or not self._is_node_supported(node): | ||
| continue | ||
|
|
||
| partition_id = next(new_partition_id) | ||
| nodes_order[node] = partition_id | ||
| partitions_order[partition_id] = partition_id | ||
| partitions_by_id[partition_id] = Partition(id=partition_id, nodes=[node]) | ||
| assignment[node] = partition_id | ||
| partition_users[partition_id] = set(node.users) | ||
|
|
||
| # Update partition map | ||
| for user in node.users: | ||
| target_id = assignment.get(user) | ||
| if target_id is not None: | ||
| partition_map[partition_id].add(target_id) | ||
| partition_map[partition_id].update(partition_map[target_id]) | ||
|
|
||
| def _merge_partitions( | ||
| self, | ||
| partitions_by_id, | ||
| assignment, | ||
| partition_users, | ||
| partition_map, | ||
| partitions_order, | ||
| ): | ||
| """Merge partitions when possible.""" | ||
| merged = True | ||
| while merged: | ||
| merged = False | ||
| partition_ids = list(partitions_by_id.keys()) | ||
|
|
||
| for i, p1 in enumerate(partition_ids): | ||
| if p1 not in partitions_by_id: | ||
| continue | ||
|
|
||
| for p2 in partition_ids[i + 1 :]: | ||
| if p2 not in partitions_by_id: | ||
| continue | ||
|
|
||
| # Try to merge partitions if it doesn't create cycles | ||
| if self._can_merge_partitions(p1, p2, partitions_by_id): | ||
| self._perform_partition_merge( | ||
| p1, | ||
| p2, | ||
| partitions_by_id, | ||
| assignment, | ||
| partition_users, | ||
| partition_map, | ||
| partitions_order, | ||
| ) | ||
| merged = True | ||
| break | ||
|
|
||
| if merged: | ||
| break | ||
|
|
||
| def _perform_partition_merge( | ||
| self, | ||
| p1, | ||
| p2, | ||
| partitions_by_id, | ||
| assignment, | ||
| partition_users, | ||
| partition_map, | ||
| partitions_order, | ||
| ): | ||
| """Merge partition p2 into p1.""" | ||
| # Merge p2 into p1 | ||
| partitions_by_id[p1].nodes.update(partitions_by_id[p2].nodes) | ||
| for node in partitions_by_id[p2].nodes: | ||
| assignment[node] = p1 | ||
|
|
||
| # Update partition users | ||
| all_users = partition_users[p1] | partition_users[p2] | ||
| all_users.difference_update(partitions_by_id[p1].nodes) | ||
| partition_users[p1] = all_users | ||
|
|
||
| # Update partition map | ||
| partition_map[p1].update(partition_map[p2]) | ||
|
|
||
| # Update partition order | ||
| partitions_order[p1] = min(partitions_order[p1], partitions_order[p2]) | ||
|
|
||
| # Remove p2 | ||
| del partitions_by_id[p2] | ||
| del partition_users[p2] | ||
| del partitions_order[p2] | ||
| if p2 in partition_map: | ||
| del partition_map[p2] | ||
|
|
||
| def _process_getitem_nodes(self, partitions_by_id, assignment): | ||
| """Post-process getitem nodes.""" | ||
| nodes_reassignment = {} | ||
|
|
||
| for node in self.graph_module.graph.nodes: | ||
| # Check if all users are getitem nodes | ||
| is_tuple_output = True | ||
| for user in node.users: | ||
| if ( | ||
| user.op != "call_function" | ||
| or _get_qualified_name(user.target) != "_operator.getitem" | ||
| ): | ||
| is_tuple_output = False | ||
| break | ||
|
|
||
| # Node has tuple outputs, reassign all following getitem nodes into node's partition | ||
| if is_tuple_output: | ||
| id = assignment.get(node, None) | ||
| if id is not None: | ||
| for user in node.users: | ||
| if user in assignment and assignment.get(user, None) != id: | ||
| nodes_reassignment[user] = id | ||
|
|
||
| # Apply reassignments | ||
| for node, id in nodes_reassignment.items(): | ||
| if node in assignment: | ||
| partitions_by_id[assignment[node]].remove_node(node) | ||
|
|
||
| assignment[node] = id | ||
| partitions_by_id[id].add_node(node) | ||
|
|
||
| def _filter_single_node_partitions(self, partitions_by_id): | ||
| """Filter out single node partitions if needed.""" | ||
| if self.allows_single_node_partition: | ||
| return | ||
|
|
||
| default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"} | ||
| non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops or [])) | ||
| partitions_to_remove = [] | ||
|
|
||
| for id, partition in partitions_by_id.items(): | ||
| compute_node_count = 0 | ||
| for node in partition.nodes: | ||
| if node.op == "call_function": | ||
| assert callable(node.target) | ||
| target_name = _get_qualified_name(node.target) | ||
|
|
||
| if target_name not in non_compute_ops: | ||
| compute_node_count += 1 | ||
|
|
||
| if ( | ||
| self.allowed_single_node_partition_ops | ||
| and target_name in self.allowed_single_node_partition_ops | ||
| ): | ||
| compute_node_count += 1 | ||
|
|
||
| if compute_node_count <= 1: | ||
| partitions_to_remove.append(id) | ||
|
|
||
| for id in partitions_to_remove: | ||
| del partitions_by_id[id] | ||
|
|
||
| def propose_partitions(self) -> list[Partition]: | ||
| """ | ||
| Propose partitions for the graph module based on node groups and operator support. | ||
|
|
||
| Returns: | ||
| A list of proposed partitions. | ||
| """ | ||
| # Initialize data structures | ||
| partition_map = collections.defaultdict( | ||
| set | ||
| ) # Maps partition IDs to reachable partition IDs | ||
| assignment = {} # Maps nodes to partition IDs | ||
| partitions_by_id = {} # Maps partition IDs to partitions | ||
| nodes_order = {} # Maps nodes to topological order | ||
| partitions_order = {} # Maps partition IDs to minimum topological order | ||
| partition_users = {} # Maps partition IDs to partition users | ||
| new_partition_id = itertools.count() | ||
|
|
||
| # Process nodes in predefined groups | ||
| self._process_node_groups( | ||
| new_partition_id, | ||
| partitions_by_id, | ||
| assignment, | ||
| nodes_order, | ||
| partitions_order, | ||
| partition_users, | ||
| partition_map, | ||
| ) | ||
|
|
||
| # Process remaining nodes | ||
| self._process_remaining_nodes( | ||
| new_partition_id, | ||
| partitions_by_id, | ||
| assignment, | ||
| nodes_order, | ||
| partitions_order, | ||
| partition_users, | ||
| partition_map, | ||
| ) | ||
|
|
||
| # Merge partitions when possible | ||
| self._merge_partitions( | ||
| partitions_by_id, | ||
| assignment, | ||
| partition_users, | ||
| partition_map, | ||
| partitions_order, | ||
| ) | ||
|
|
||
| # Post-process getitem nodes | ||
| self._process_getitem_nodes(partitions_by_id, assignment) | ||
|
|
||
| # Filter single node partitions if needed | ||
| self._filter_single_node_partitions(partitions_by_id) | ||
|
|
||
| # Return non-empty partitions | ||
| return [p for p in partitions_by_id.values() if p.size() > 0] | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Add a doc script