
Commit bfb6bda

working on group based partitioner

1 parent b74c68d

2 files changed: +668, -0 lines
Shown file: 299 additions, 0 deletions

# mypy: allow-untyped-defs
import collections
import itertools
import logging
from collections.abc import Sequence
from typing import List, Optional

from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name, Node
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.passes.operator_support import OperatorSupportBase


logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class _DependencyViewer:
    def __init__(self, graph_module: GraphModule):
        self.downstreams = collections.defaultdict(set)
        self.upstreams = collections.defaultdict(set)

        for node in reversed(graph_module.graph.nodes):
            for output_node in node.users:
                # add output_node and output_node's downstream dependency
                self.downstreams[node].add(output_node)
                self.downstreams[node].update(self.downstreams[output_node])

        for node in graph_module.graph.nodes:
            for input_node in node.all_input_nodes:
                self.upstreams[node].add(input_node)
                self.upstreams[node].update(self.upstreams[input_node])

    def downstreams_of(self, node: Node) -> set[Node]:
        return self.downstreams[node]

    def upstreams_of(self, node: Node) -> set[Node]:
        return self.upstreams[node]
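
# Illustrative note (not part of the original commit): for a straight-line
# graph a -> b -> c, _DependencyViewer stores transitive closures, so
# downstreams_of(a) == {b, c} and upstreams_of(c) == {a, b}.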


class GroupBasedPartitioner(CapabilityBasedPartitioner):
    """
    A specialized partitioner that extends the CapabilityBasedPartitioner from PyTorch FX.

    GroupBasedPartitioner allows explicit grouping of nodes into partitions based on
    predefined node groups, while also supporting automatic partitioning for nodes not
    included in any group. A node may belong to at most one group.

    Features:
    - Explicit Node Grouping: Allows specifying groups of nodes that should be kept together
      in the same partition.
    - Automatic Partitioning: Nodes not included in any explicit group are automatically
      partitioned based on operator support.
    - Cycle Prevention: Ensures that partitioning doesn't create cycles in the execution graph.
    - Single Node Partition Control: Options to allow or disallow single-node partitions,
      with exceptions for specific operations.

    Args:
        graph_module: The FX GraphModule to be partitioned.
        operator_support: Interface to determine if a node is supported by the target backend.
        allows_single_node_partition: Whether to allow single-node partitions. Default: False.
        non_compute_ops: Operations not counted when deciding whether a partition is a
            single-node partition. Default: None.
        allowed_single_node_partition_ops: Operations allowed as single-node partitions.
            Default: None.
        node_groups: Lists of nodes to group together in partitions. Default: None.
    """

    def __init__(
        self,
        graph_module: GraphModule,
        operator_support: OperatorSupportBase,
        allows_single_node_partition: bool = False,
        non_compute_ops: Optional[Sequence[str]] = None,
        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
        node_groups: Optional[List[List[Node]]] = None,
    ) -> None:
        super().__init__(
            graph_module=graph_module,
            operator_support=operator_support,
            allows_single_node_partition=allows_single_node_partition,
            non_compute_ops=non_compute_ops,
            allowed_single_node_partition_ops=allowed_single_node_partition_ops,
        )
        self.dependency_viewer = _DependencyViewer(graph_module)
        self.node_groups = (
            [set(node_group) for node_group in node_groups] if node_groups else None
        )
        self.node_to_group = collections.defaultdict(int)
        self.all_nodes_in_groups = set()
        if node_groups:
            for i, group in enumerate(self.node_groups):
                for node in group:
                    # A node is not allowed to appear in multiple groups
                    if node in self.node_to_group:
                        raise ValueError(f"Node {node} exists in multiple groups.")
                    self.node_to_group[node] = i
                    self.all_nodes_in_groups.add(node)

    def propose_partitions(self) -> list[Partition]:
        # partition_map is a mapping from partition id to a set of partition ids.
        # The value set contains all the partition ids that can be reached by
        # doing a DFS starting from the partition id in the key.
        partition_map: dict[int, set] = collections.defaultdict(set)

        # assumption: nodes in the candidate list are sorted in topological order
        assignment: dict[Node, int] = {}  # mapping from node to partition_id
        partitions_by_id: dict[int, Partition] = {}  # mapping from partition_id to partition
        nodes_order: dict[Node, int] = {}  # mapping from nodes to reversed topological order
        partitions_order: dict[int, int] = {}  # mapping from partition_id to minimum topo order of nodes in the partition
        partition_users: dict[int, set] = {}  # mapping from partition_id to partition users
        new_partition_id = itertools.count()

        group_to_partition_id = {}  # mapping from group id to partition id

        # Try to merge partitions that don't create cycles
        def can_merge(p1, p2):
            # Check if merging would create a cycle
            p1_nodes = set(partitions_by_id[p1].nodes.keys())
            p2_nodes = set(partitions_by_id[p2].nodes.keys())

            # Create a combined set of nodes from both partitions
            combined_nodes = p1_nodes.union(p2_nodes)

            # For each node in the combined partition, check whether any of its
            # external downstream nodes has a downstream node back inside the
            # combined partition
            for node in combined_nodes:
                # Get all downstream nodes that are not in the combined partition
                external_downstreams = {
                    n
                    for n in self.dependency_viewer.downstreams_of(node)
                    if n not in combined_nodes
                }
                for external_node in external_downstreams:
                    for downstream_node in self.dependency_viewer.downstreams_of(
                        external_node
                    ):
                        if downstream_node in combined_nodes:
                            return False

            return True
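
        # Illustrative note (not part of the original commit): if a -> x -> b
        # with a in p1, b in p2, and x outside both partitions, then merging
        # p1 and p2 would make the merged partition and x depend on each
        # other, i.e. a cycle, so can_merge returns False.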

        # Preprocess grouped nodes so they are placed in the same partition
        if self.node_groups:
            for i, group in enumerate(self.node_groups):
                # Create a partition for each group
                partition_id = next(new_partition_id)
                partition = Partition(id=partition_id, nodes=set())
                partitions_by_id[partition_id] = partition
                partitions_order[partition_id] = partition_id
                group_to_partition_id[i] = partition_id

                # Add all supported nodes from the group to the partition
                for node in group:
                    if self._is_node_supported(node):
                        partition.add_node(node)
                        assignment[node] = partition_id
                        nodes_order[node] = partition_id

                # Set partition users
                partition_users[partition_id] = {
                    user
                    for node in partition.nodes
                    for user in node.users
                    if user not in partition.nodes
                }

                # Update partition map
                for node in partition.nodes:
                    for user in node.users:
                        target_id = assignment.get(user)
                        if target_id is not None and target_id != partition_id:
                            partition_map[partition_id].add(target_id)
                            partition_map[partition_id].update(partition_map[target_id])

        # Process remaining nodes
        for node in reversed(self.graph_module.graph.nodes):
            if node in assignment or not self._is_node_supported(node):
                continue

            partition_id = next(new_partition_id)
            nodes_order[node] = partition_id
            partitions_order[partition_id] = partition_id
            partitions_by_id[partition_id] = Partition(id=partition_id, nodes=[node])
            assignment[node] = partition_id
            partition_users[partition_id] = set(node.users)

            # Update partition map
            for user in node.users:
                target_id = assignment.get(user)
                if target_id is not None:
                    partition_map[partition_id].add(target_id)
                    partition_map[partition_id].update(partition_map[target_id])

        # Merge partitions when possible
        merged = True
        while merged:
            merged = False
            partition_ids = list(partitions_by_id.keys())
            for i, p1 in enumerate(partition_ids):
                if p1 not in partitions_by_id:
                    continue

                for p2 in partition_ids[i + 1 :]:
                    if p2 not in partitions_by_id:
                        continue

                    # Try to merge partitions if it doesn't create cycles
                    if can_merge(p1, p2):
                        # Merge p2 into p1
                        partitions_by_id[p1].nodes.update(partitions_by_id[p2].nodes)
                        for node in partitions_by_id[p2].nodes:
                            assignment[node] = p1

                        # Update partition users
                        all_users = partition_users[p1] | partition_users[p2]
                        all_users.difference_update(partitions_by_id[p1].nodes)
                        partition_users[p1] = all_users

                        # Update partition map
                        partition_map[p1].update(partition_map[p2])

                        # Update partition order
                        partitions_order[p1] = min(
                            partitions_order[p1], partitions_order[p2]
                        )

                        # Remove p2
                        del partitions_by_id[p2]
                        del partition_users[p2]
                        del partitions_order[p2]
                        if p2 in partition_map:
                            del partition_map[p2]

                        merged = True
                        break

                if merged:
                    break

        # Post-processing for getitem nodes
        nodes_reassignment = {}
        for node in self.graph_module.graph.nodes:
            is_tuple_output = True
            for user in node.users:
                if (
                    user.op != "call_function"
                    or _get_qualified_name(user.target) != "_operator.getitem"
                ):
                    is_tuple_output = False
                    break

            # node has tuple outputs: re-assign all following getitem nodes
            # into node's partition
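            # (Illustrative, not in the original commit: the two getitem users
            # of a torch.ops.aten.max_pool2d_with_indices call, for example,
            # would follow that call into its partition.)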
            if is_tuple_output:
                id = assignment.get(node, None)
                if id is not None:
                    for user in node.users:
                        if user in assignment and assignment.get(user, None) != id:
                            nodes_reassignment[user] = id

        for node, id in nodes_reassignment.items():
            if node in assignment:
                partitions_by_id[assignment[node]].remove_node(node)

            assignment[node] = id
            partitions_by_id[id].add_node(node)

        # Filter single node partitions if needed
        if not self.allows_single_node_partition:
            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
            non_compute_ops = default_non_compute_ops.union(
                set(self.non_compute_ops or [])
            )
            partitions_to_remove = []
            for id, partition in partitions_by_id.items():
                compute_node_count = 0
                for node in partition.nodes:
                    if node.op == "call_function":
                        assert callable(node.target)
                        if _get_qualified_name(node.target) not in non_compute_ops:
                            compute_node_count += 1
                        if (
                            self.allowed_single_node_partition_ops
                            and _get_qualified_name(node.target)
                            in self.allowed_single_node_partition_ops
                        ):
                            compute_node_count += 1
                if compute_node_count <= 1:
                    partitions_to_remove.append(id)
            for id in partitions_to_remove:
                del partitions_by_id[id]

        return [p for p in partitions_by_id.values() if p.size() > 0]
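
As a rough usage sketch (not part of the commit; the toy module and the
AlwaysSupported policy below are hypothetical, and it assumes
GroupBasedPartitioner is importable from the file above):

    import operator

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.operator_support import OperatorSupportBase

    class AlwaysSupported(OperatorSupportBase):
        # Hypothetical policy: claim every call_function node for the backend.
        def is_node_supported(self, submodules, node) -> bool:
            return node.op == "call_function"

    class M(torch.nn.Module):
        def forward(self, x):
            a = x + 1
            b = a * 2
            return b - 3

    gm = symbolic_trace(M())
    add_node = next(n for n in gm.graph.nodes if n.target is operator.add)
    mul_node = next(n for n in gm.graph.nodes if n.target is operator.mul)

    partitioner = GroupBasedPartitioner(
        gm,
        AlwaysSupported(),
        allows_single_node_partition=True,
        node_groups=[[add_node, mul_node]],  # keep add and mul together
    )
    partitions = partitioner.propose_partitions()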
