Skip to content

Commit 7c82bb4

Browse files
committed
working on group based partitioner
1 parent b74c68d commit 7c82bb4

File tree

2 files changed

+641
-0
lines changed

2 files changed

+641
-0
lines changed
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
# mypy: allow-untyped-defs
2+
import collections
3+
import itertools
4+
import logging
5+
from collections.abc import Sequence
6+
from typing import List, Optional
7+
8+
from torch.fx.graph_module import GraphModule
9+
from torch.fx.node import _get_qualified_name, Node
10+
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
11+
from torch.fx.passes.operator_support import OperatorSupportBase
12+
13+
14+
# Module-level logger for this partitioning pass.
# NOTE(review): setting a level on a library logger overrides the consuming
# application's logging configuration for this module — confirm this is intended.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
16+
17+
18+
class _DependencyViewer:
19+
def __init__(self, graph_module: GraphModule):
20+
self.downstreams = collections.defaultdict(set)
21+
self.upstreams = collections.defaultdict(set)
22+
23+
for node in reversed(graph_module.graph.nodes):
24+
for output_node in node.users:
25+
# add output_node and output_node's downstream dependency
26+
self.downstreams[node].add(output_node)
27+
self.downstreams[node].update(self.downstreams[output_node])
28+
29+
for node in graph_module.graph.nodes:
30+
for input_node in node.all_input_nodes:
31+
self.upstreams[node].add(input_node)
32+
self.upstreams[node].update(self.upstreams[input_node])
33+
34+
def downstreams_of(self, node: Node) -> set[Node]:
35+
return self.downstreams[node]
36+
37+
def upstreams_of(self, node: Node) -> set[Node]:
38+
return self.upstreams[node]
39+
40+
41+
class GroupBasedPartitioner(CapabilityBasedPartitioner):
    """Capability-based partitioner that additionally honors caller-supplied
    node groups.

    Each group in ``node_groups`` is seeded into its own partition up front;
    remaining supported nodes get singleton partitions, and partitions are
    then greedily merged whenever merging does not introduce a cycle in the
    partitioned graph.
    """

    def __init__(
        self,
        graph_module: GraphModule,
        operator_support: OperatorSupportBase,
        allows_single_node_partition: bool = False,
        non_compute_ops: Optional[Sequence[str]] = None,
        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
        node_groups: Optional[List[List[Node]]] = None,
    ) -> None:
        """
        Args:
            graph_module: The module whose graph will be partitioned.
            operator_support: Decides which nodes are eligible for partitions.
            allows_single_node_partition: Keep partitions with <= 1 compute node.
            non_compute_ops: Extra op names counted as non-compute when
                filtering single-node partitions.
            allowed_single_node_partition_ops: Ops that may stand alone in a
                partition even when single-node partitions are filtered.
            node_groups: Optional lists of nodes that must be seeded into the
                same partition.
        """
        super().__init__(
            graph_module=graph_module,
            operator_support=operator_support,
            allows_single_node_partition=allows_single_node_partition,
            non_compute_ops=non_compute_ops,
            allowed_single_node_partition_ops=allowed_single_node_partition_ops,
        )
        # Precomputed transitive upstream/downstream closures used for the
        # cycle check in propose_partitions.
        self.dependency_viewer = _DependencyViewer(graph_module)
        # Groups are stored as sets; None when no groups were given.
        self.node_groups: Optional[List[set]] = (
            [set(node_group) for node_group in node_groups] if node_groups else None
        )
        # NOTE(review): defaultdict(int) maps any node NOT in a group to 0,
        # which is indistinguishable from membership in group 0 — confirm
        # callers always pair lookups with all_nodes_in_groups.
        self.node_to_group: dict = collections.defaultdict(int)
        # Union of all nodes appearing in any group.
        self.all_nodes_in_groups: set = set()
        if node_groups:
            for i, group in enumerate(self.node_groups):
                for node in group:
                    self.node_to_group[node] = i
                    self.all_nodes_in_groups.add(node)

    def propose_partitions(self) -> list[Partition]:
        """Propose partitions of the graph honoring the configured node groups.

        Four phases: (1) seed one partition per node group, (2) give every
        remaining supported node a singleton partition, (3) greedily merge
        partitions while no merge creates a cycle, (4) fold getitem users
        into their producer's partition and optionally drop single-compute
        partitions.

        Returns:
            Non-empty partitions, cycle-free with respect to the graph's
            dependency order.
        """
        # partition_map is a mapping from partition id to a set of partition id's.
        # The value set contains all the partition ids that can be reached by doing a
        # DFS starting from the partition id in the key.
        # NOTE(review): partition_map, nodes_order, partition_users and
        # group_to_partition_id are maintained below but never read when
        # deciding merges (can_merge relies solely on dependency_viewer) —
        # presumably bookkeeping for work in progress; verify before removing.
        partition_map: dict[int, set] = collections.defaultdict(set)

        # assumptions: nodes in candidate list is sorted in topological order
        assignment: dict[Node, int] = {}  # mapping from node to partition_id
        partitions_by_id: dict[int, Partition] = (
            {}
        )  # mapping from partition_id to partition
        nodes_order: dict[Node, int] = (
            {}
        )  # mapping from nodes to reversed topological order
        partitions_order: dict[int, int] = (
            {}
        )  # mapping from partition_id to minimum topo order of nodes in partition
        partition_users: dict[int, set] = (
            {}
        )  # mapping from partition_id to partition users
        new_partition_id = itertools.count()

        group_to_partition_id: dict[int, int] = {}  # mapping from group id to partition id

        # Try to merge partitions that don't create cycles
        def can_merge(p1, p2):
            # Check if merging would create a cycle
            p1_nodes = set(partitions_by_id[p1].nodes.keys())
            p2_nodes = set(partitions_by_id[p2].nodes.keys())

            # Create a combined set of nodes from both partitions
            combined_nodes = p1_nodes.union(p2_nodes)

            # For each node in the combined partition, check if any of its external downstream nodes
            # have downstream nodes that are in the combined partition
            # (i.e. a path that leaves the merged partition and re-enters it,
            # which would make the partition graph cyclic).
            for node in combined_nodes:
                # Get all downstream nodes that are not in the combined partition
                external_downstreams = {
                    n
                    for n in self.dependency_viewer.downstreams_of(node)
                    if n not in combined_nodes
                }
                # Check if any of these external downstream nodes have downstream nodes that are in the combined partition
                for external_node in external_downstreams:
                    for downstream_node in self.dependency_viewer.downstreams_of(
                        external_node
                    ):
                        if downstream_node in combined_nodes:
                            return False

            return True

        # Preprocess nodes to put them in same partition
        if self.node_groups:
            for i, group in enumerate(self.node_groups):
                # Create a partition for each group
                partition_id = next(new_partition_id)
                partition = Partition(id=partition_id, nodes=set())
                partitions_by_id[partition_id] = partition
                # Partition ids are drawn from an increasing counter, so the
                # id doubles as the partition's creation order here.
                partitions_order[partition_id] = partition_id
                group_to_partition_id[i] = partition_id

                # Add all supported nodes from the group to the partition
                # (unsupported group members are silently skipped).
                for node in group:
                    if self._is_node_supported(node):
                        partition.add_node(node)
                        assignment[node] = partition_id
                        nodes_order[node] = partition_id

                # Set partition users
                partition_users[partition_id] = {
                    user
                    for node in partition.nodes
                    for user in node.users
                    if user not in partition.nodes
                }

                # Update partition map
                for node in partition.nodes:
                    for user in node.users:
                        target_id = assignment.get(user)
                        if target_id is not None and target_id != partition_id:
                            partition_map[partition_id].add(target_id)
                            partition_map[partition_id].update(partition_map[target_id])

        # Process remaining nodes: every supported node not already claimed by
        # a group gets its own singleton partition.
        for node in reversed(self.graph_module.graph.nodes):
            if node in assignment or not self._is_node_supported(node):
                continue

            partition_id = next(new_partition_id)
            nodes_order[node] = partition_id
            partitions_order[partition_id] = partition_id
            partitions_by_id[partition_id] = Partition(id=partition_id, nodes=[node])
            assignment[node] = partition_id
            partition_users[partition_id] = set(node.users)

            # Update partition map
            for user in node.users:
                target_id = assignment.get(user)
                if target_id is not None:
                    partition_map[partition_id].add(target_id)
                    partition_map[partition_id].update(partition_map[target_id])

        # Merge partitions when possible. Restarts the pairwise scan after
        # every successful merge until a full pass makes no change.
        # NOTE(review): can_merge also returns True for two partitions with no
        # dependency relationship at all, so unrelated subgraphs can be fused
        # into one partition — confirm this maximal-fusion behavior is intended.
        merged = True
        while merged:
            merged = False
            partition_ids = list(partitions_by_id.keys())
            for i, p1 in enumerate(partition_ids):
                if p1 not in partitions_by_id:
                    continue

                for p2 in partition_ids[i + 1 :]:
                    if p2 not in partitions_by_id:
                        continue

                    # Try to merge partitions if it doesn't create cycles
                    if can_merge(p1, p2):
                        # Merge p2 into p1
                        partitions_by_id[p1].nodes.update(partitions_by_id[p2].nodes)
                        for node in partitions_by_id[p2].nodes:
                            assignment[node] = p1

                        # Update partition users: union of both partitions'
                        # users minus anything now internal to the merge.
                        all_users = partition_users[p1] | partition_users[p2]
                        all_users.difference_update(partitions_by_id[p1].nodes)
                        partition_users[p1] = all_users

                        # Update partition map
                        partition_map[p1].update(partition_map[p2])

                        # Update partition order
                        partitions_order[p1] = min(
                            partitions_order[p1], partitions_order[p2]
                        )

                        # Remove p2
                        del partitions_by_id[p2]
                        del partition_users[p2]
                        del partitions_order[p2]
                        if p2 in partition_map:
                            del partition_map[p2]

                        merged = True
                        break

                if merged:
                    break

        # Post-processing for getitem nodes
        nodes_reassignment: dict[Node, int] = {}
        for node in self.graph_module.graph.nodes:
            # A node is treated as tuple-output when every user is a
            # _operator.getitem call (vacuously true for nodes with no users,
            # which is harmless since the user loop below is then empty).
            is_tuple_output = True
            for user in node.users:
                if (
                    user.op != "call_function"
                    or _get_qualified_name(user.target) != "_operator.getitem"
                ):
                    is_tuple_output = False
                    break

            # node has tuple outputs, re-assign all following getitem node into node's partition
            if is_tuple_output:
                id = assignment.get(node, None)
                if id is not None:
                    for user in node.users:
                        if user in assignment and assignment.get(user, None) != id:
                            nodes_reassignment[user] = id

        for node, id in nodes_reassignment.items():
            if node in assignment:
                partitions_by_id[assignment[node]].remove_node(node)

            assignment[node] = id
            partitions_by_id[id].add_node(node)

        # Filter single node partitions if needed: drop partitions that hold
        # at most one "compute" node, where views/getitems never count and
        # allowed_single_node_partition_ops count double so they survive alone.
        if not self.allows_single_node_partition:
            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
            non_compute_ops = default_non_compute_ops.union(
                set(self.non_compute_ops or [])
            )
            partitions_to_remove = []
            for id, partition in partitions_by_id.items():
                compute_node_count = 0
                for node in partition.nodes:
                    if node.op == "call_function":
                        assert callable(node.target)
                        if _get_qualified_name(node.target) not in non_compute_ops:
                            compute_node_count += 1
                        if (
                            self.allowed_single_node_partition_ops
                            and _get_qualified_name(node.target)
                            in self.allowed_single_node_partition_ops
                        ):
                            compute_node_count += 1
                if compute_node_count <= 1:
                    partitions_to_remove.append(id)
            for id in partitions_to_remove:
                del partitions_by_id[id]

        # Empty partitions (e.g. a group whose nodes were all unsupported)
        # are dropped here.
        return [p for p in partitions_by_id.values() if p.size() > 0]

0 commit comments

Comments
 (0)