
Commit afefd1d

working on group based partitioner
1 parent b74c68d commit afefd1d

File tree

2 files changed: +601 -0 lines changed

Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
# mypy: allow-untyped-defs
import collections
import itertools
import logging
from collections.abc import Sequence
from typing import Optional

from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name, Node
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.passes.operator_support import OperatorSupportBase


logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

class _DependencyViewer:
    """Precomputes the transitive closure of each node's users and inputs."""

    def __init__(self, graph_module: GraphModule):
        self.downstreams = collections.defaultdict(set)
        self.upstreams = collections.defaultdict(set)

        # Reverse topological order: each user's downstream set is already
        # complete when it is folded into the current node's set.
        for node in reversed(graph_module.graph.nodes):
            for output_node in node.users:
                # add output_node and output_node's downstream dependency
                self.downstreams[node].add(output_node)
                self.downstreams[node].update(self.downstreams[output_node])

        # Forward order: each input's upstream set is already complete.
        for node in graph_module.graph.nodes:
            for input_node in node.all_input_nodes:
                self.upstreams[node].add(input_node)
                self.upstreams[node].update(self.upstreams[input_node])

    def downstreams_of(self, node: Node) -> set[Node]:
        return self.downstreams[node]

    def upstreams_of(self, node: Node) -> set[Node]:
        return self.upstreams[node]

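# Illustrative example (not part of the commit): for a straight-line graph
# a -> b -> c, the viewer answers downstreams_of(a) == {b, c} and
# upstreams_of(c) == {a, b}.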
class GroupBasedPartitioner(CapabilityBasedPartitioner):
    def __init__(
        self,
        graph_module: GraphModule,
        operator_support: OperatorSupportBase,
        allows_single_node_partition: bool = False,
        non_compute_ops: Optional[Sequence[str]] = None,
        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
        node_groups: Optional[list[list[Node]]] = None,
    ) -> None:
        super().__init__(
            graph_module=graph_module,
            operator_support=operator_support,
            allows_single_node_partition=allows_single_node_partition,
            non_compute_ops=non_compute_ops,
            allowed_single_node_partition_ops=allowed_single_node_partition_ops,
        )
        logger.debug("node groups: %s", node_groups)
        self.dependency_viewer = _DependencyViewer(graph_module)
        self.node_groups: Optional[list[set[Node]]] = (
            [set(node_group) for node_group in node_groups] if node_groups else None
        )
        self.node_to_group: dict[Node, int] = {}
        self.all_nodes_in_groups: set[Node] = set()
        if node_groups:
            for i, group in enumerate(self.node_groups):
                for node in group:
                    self.node_to_group[node] = i
                    self.all_nodes_in_groups.add(node)

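    # Illustrative example (not part of the commit): with node_groups=[[a, b], [c]],
    # __init__ above leaves node_to_group == {a: 0, b: 0, c: 1} and
    # all_nodes_in_groups == {a, b, c}.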
    def propose_partitions(self) -> list[Partition]:
        # partition_map is a mapping from partition id to a set of partition ids.
        # The value set contains all the partition ids that can be reached by doing a
        # DFS starting from the partition id in the key.
        partition_map: dict[int, set] = collections.defaultdict(set)

        # assumption: nodes in the candidate list are sorted in topological order
        assignment: dict[Node, int] = {}  # mapping from node to partition_id
        partitions_by_id: dict[int, Partition] = {}  # mapping from partition_id to partition
        nodes_order: dict[Node, int] = {}  # mapping from node to reversed topological order
        partitions_order: dict[int, int] = {}  # mapping from partition_id to minimum topo order of its nodes
        partition_users: dict[int, set] = {}  # mapping from partition_id to partition users
        new_partition_id = itertools.count()

        group_to_partition_id = {}  # mapping from group id to partition id

        def can_merge(p1, p2):
            # Check whether merging p1 and p2 would create a cycle, i.e. a path
            # that leaves the combined partition and re-enters it later.
            p1_nodes = set(partitions_by_id[p1].nodes.keys())
            p2_nodes = set(partitions_by_id[p2].nodes.keys())

            # Create a combined set of nodes from both partitions
            combined_nodes = p1_nodes.union(p2_nodes)

            # For each node in the combined partition, check whether any of its
            # external downstream nodes can reach back into the combined partition
            for node in combined_nodes:
                logger.debug("node: %s", node)
                # Get all downstream nodes that are not in the combined partition
                external_downstreams = {
                    n for n in self.dependency_viewer.downstreams_of(node)
                    if n not in combined_nodes
                }
                logger.debug("external_downstreams: %s", external_downstreams)
                for external_node in external_downstreams:
                    logger.debug("downstreams: %s", self.dependency_viewer.downstreams_of(external_node))
                    for downstream_node in self.dependency_viewer.downstreams_of(external_node):
                        if downstream_node in combined_nodes:
                            return False

            return True

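        # Illustrative sketch (not part of the commit) of what can_merge rejects:
        # given a -> b -> c plus a -> c, merging {a, c} while b stays outside
        # would yield a partition that feeds b through a and consumes b's output
        # through c, i.e. a cycle between the partition and b.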
        # Pre-process node groups: put every node of a group into the same partition
        if self.node_groups:
            for i, group in enumerate(self.node_groups):
                # Create a partition for each group
                partition_id = next(new_partition_id)
                partition = Partition(id=partition_id, nodes=set())
                partitions_by_id[partition_id] = partition
                partitions_order[partition_id] = partition_id
                group_to_partition_id[i] = partition_id

                # Add all supported nodes from the group to the partition
                for node in group:
                    if self._is_node_supported(node):
                        partition.add_node(node)
                        assignment[node] = partition_id
                        nodes_order[node] = partition_id

                # Set partition users
                partition_users[partition_id] = {
                    user
                    for node in partition.nodes
                    for user in node.users
                    if user not in partition.nodes
                }

                # Update partition map
                for node in partition.nodes:
                    for user in node.users:
                        target_id = assignment.get(user)
                        if target_id is not None and target_id != partition_id:
                            partition_map[partition_id].add(target_id)
                            partition_map[partition_id].update(partition_map[target_id])

        # Process remaining nodes: each starts in its own single-node partition
        for node in reversed(self.graph_module.graph.nodes):
            if node in assignment or not self._is_node_supported(node):
                continue

            partition_id = next(new_partition_id)
            nodes_order[node] = partition_id
            partitions_order[partition_id] = partition_id
            partitions_by_id[partition_id] = Partition(id=partition_id, nodes=[node])
            assignment[node] = partition_id
            partition_users[partition_id] = set(node.users)

            # Update partition map
            for user in node.users:
                target_id = assignment.get(user)
                if target_id is not None:
                    partition_map[partition_id].add(target_id)
                    partition_map[partition_id].update(partition_map[target_id])

        # Log the proposed partitions for debugging
        for partition_id, partition in partitions_by_id.items():
            logger.debug("Partition %s:", partition_id)
            for node in partition.nodes:
                logger.debug("\t%s", node.format_node())

        # Merge partitions when possible
        merged = True
        while merged:
            merged = False
            partition_ids = list(partitions_by_id.keys())
            for i, p1 in enumerate(partition_ids):
                if p1 not in partitions_by_id:
                    continue

                for p2 in partition_ids[i + 1:]:
                    if p2 not in partitions_by_id:
                        continue

                    # Try to merge partitions if it doesn't create cycles
                    if can_merge(p1, p2):
                        logger.debug("merging partitions %s and %s together", p1, p2)
                        # Merge p2 into p1
                        partitions_by_id[p1].nodes.update(partitions_by_id[p2].nodes)
                        for node in partitions_by_id[p2].nodes:
                            assignment[node] = p1

                        # Update partition users
                        all_users = partition_users[p1] | partition_users[p2]
                        all_users.difference_update(partitions_by_id[p1].nodes)
                        partition_users[p1] = all_users

                        # Update partition map
                        partition_map[p1].update(partition_map[p2])

                        # Update partition order
                        partitions_order[p1] = min(partitions_order[p1], partitions_order[p2])

                        # Remove p2
                        del partitions_by_id[p2]
                        del partition_users[p2]
                        del partitions_order[p2]
                        if p2 in partition_map:
                            del partition_map[p2]

                        merged = True
                        break

                if merged:
                    break

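        # Note (not part of the commit): each successful merge breaks out and
        # rescans a fresh snapshot of partition ids, so the while-loop terminates
        # only once a full pairwise pass finds no cycle-free merge.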
        # Post-processing for getitem nodes
        nodes_reassignment: dict[Node, int] = {}
        for node in self.graph_module.graph.nodes:
            is_tuple_output = True
            for user in node.users:
                if (
                    user.op != "call_function"
                    or _get_qualified_name(user.target) != "_operator.getitem"
                ):
                    is_tuple_output = False
                    break

            # node has tuple outputs; re-assign all following getitem nodes into node's partition
            if is_tuple_output:
                id = assignment.get(node, None)
                if id is not None:
                    for user in node.users:
                        if user in assignment and assignment.get(user, None) != id:
                            nodes_reassignment[user] = id

        for node, id in nodes_reassignment.items():
            if node in assignment:
                partitions_by_id[assignment[node]].remove_node(node)

            assignment[node] = id
            partitions_by_id[id].add_node(node)

        # Filter single-node partitions if needed
        if not self.allows_single_node_partition:
            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops or []))
            partitions_to_remove = []
            for id, partition in partitions_by_id.items():
                compute_node_count = 0
                for node in partition.nodes:
                    if node.op == "call_function":
                        assert callable(node.target)
                        if _get_qualified_name(node.target) not in non_compute_ops:
                            compute_node_count += 1
                        if (
                            self.allowed_single_node_partition_ops
                            and _get_qualified_name(node.target)
                            in self.allowed_single_node_partition_ops
                        ):
                            compute_node_count += 1
                if compute_node_count <= 1:
                    partitions_to_remove.append(id)
            for id in partitions_to_remove:
                del partitions_by_id[id]

        return [p for p in partitions_by_id.values() if p.size() > 0]
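
Since the commit is marked as work in progress, nothing in the diff exercises the new class yet. A minimal driver sketch follows; everything in it (AllOperatorSupport, fn, the chosen group) is hypothetical and not part of the commit:

from torch.fx import symbolic_trace
from torch.fx.passes.operator_support import OperatorSupportBase


class AllOperatorSupport(OperatorSupportBase):
    # Hypothetical policy: treat every call_function node as supported.
    def is_node_supported(self, submodules, node) -> bool:
        return node.op == "call_function"


def fn(x):
    a = x + 1
    b = a * 2
    return a - b


gm = symbolic_trace(fn)
add_node, mul_node = [n for n in gm.graph.nodes if n.op == "call_function"][:2]

# Pin add and mul into one group; the partitioner keeps them together and
# merges other supported nodes in only when no cycle would be created.
partitioner = GroupBasedPartitioner(
    gm,
    AllOperatorSupport(),
    node_groups=[[add_node, mul_node]],
)
for partition in partitioner.propose_partitions():
    print(partition.id, list(partition.nodes))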

0 commit comments