
Commit 4e50966

Pass Dependencies When Proposing Partitions via New Group-Based Partitioner (#12072)
### Summary

The existing capability-based partitioner in PyTorch's FX module lacks the ability to specify node dependencies that must be partitioned together. This can lead to incorrect partitioning of dynamically quantized linear patterns, resulting in the loss of quantization semantics. For example, in a graph with shared QDQ (Quantize-Dequantize) chains, the partitioner may incorrectly separate nodes that should remain together, leading to incorrect execution semantics. This PR addresses the issue by adding a new group-based partitioner that lets users specify groups of nodes that must be partitioned together.

### Test plan

I've created test cases that replicate existing QDQ tests, as well as additional graphs with different node dependencies. These tests verify that the new partitioner correctly groups nodes together based on the specified dependencies.
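
As an illustrative sketch of the intended API (hypothetical: the import path of `GroupBasedPartitioner` is not visible on this page, and `SupportAll` is a stand-in operator-support policy, not part of this PR):

```python
from torch.fx import symbolic_trace
from torch.fx.passes.operator_support import OperatorSupportBase

# Assumed import path; the new file's location is not shown on this page.
from torch.fx.passes.infra.group_based_partitioner import GroupBasedPartitioner


class SupportAll(OperatorSupportBase):
    """Stand-in policy: treat every call_function node as backend-supported."""

    def is_node_supported(self, submodules, node) -> bool:
        return node.op == "call_function"


def f(x):
    y = x + 1
    return y * 2, y - 3


gm = symbolic_trace(f)
nodes = {n.name: n for n in gm.graph.nodes}

# Keep `add` and `mul` in one partition, the way a shared QDQ chain's
# nodes must stay together.
partitioner = GroupBasedPartitioner(
    gm,
    SupportAll(),
    node_groups=[[nodes["add"], nodes["mul"]]],
    allows_single_node_partition=True,
)
for partition in partitioner.propose_partitions():
    print(sorted(n.name for n in partition.nodes))
```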
1 parent 4253387 commit 4e50966

2 files changed: +2061 -0 lines changed
Lines changed: 389 additions & 0 deletions
@@ -0,0 +1,389 @@
# mypy: allow-untyped-defs
import collections
import itertools
import logging
from collections.abc import Sequence
from typing import List, Optional

from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name, Node
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.passes.operator_support import OperatorSupportBase


logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class _DependencyViewer:
    def __init__(self, graph_module: GraphModule):
        self.downstreams = collections.defaultdict(set)
        self.upstreams = collections.defaultdict(set)

        for node in reversed(graph_module.graph.nodes):
            for output_node in node.users:
                # add output_node and output_node's downstream dependency
                self.downstreams[node].add(output_node)
                self.downstreams[node].update(self.downstreams[output_node])

        for node in graph_module.graph.nodes:
            for input_node in node.all_input_nodes:
                self.upstreams[node].add(input_node)
                self.upstreams[node].update(self.upstreams[input_node])

    def downstreams_of(self, node: Node) -> set[Node]:
        return self.downstreams[node]

    def upstreams_of(self, node: Node) -> set[Node]:
        return self.upstreams[node]


class GroupBasedPartitioner(CapabilityBasedPartitioner):
    """
    A specialized partitioner that extends the CapabilityBasedPartitioner from PyTorch FX.

    GroupBasedPartitioner allows for explicit grouping of nodes into partitions based on
    predefined node groups, while also supporting automatic partitioning for nodes not
    included in any group. A node may belong to at most one group.

    Features:
    - Explicit Node Grouping: Allows specifying groups of nodes that should be kept together
      in the same partition.
    - Automatic Partitioning: Nodes not included in any explicit group are automatically
      partitioned based on operator support.
    - Cycle Prevention: Ensures that partitioning doesn't create cycles in the execution graph.
    - Single Node Partition Control: Options to allow or disallow single-node partitions,
      with exceptions for specific operations.

    Args:
        graph_module: The FX GraphModule to be partitioned.
        operator_support: Interface to determine if a node is supported by the target backend.
        allows_single_node_partition: Whether to allow single-node partitions. Default: False.
        non_compute_ops: Operations not counted for single-node partition determination. Default: None.
        allowed_single_node_partition_ops: Operations allowed as single-node partitions. Default: None.
        node_groups: Lists of nodes to group together in partitions. Default: None.
    """

    def __init__(
        self,
        graph_module: GraphModule,
        operator_support: OperatorSupportBase,
        allows_single_node_partition: bool = False,
        non_compute_ops: Optional[Sequence[str]] = None,
        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
        node_groups: Optional[List[List[Node]]] = None,
    ) -> None:
        super().__init__(
            graph_module=graph_module,
            operator_support=operator_support,
            allows_single_node_partition=allows_single_node_partition,
            non_compute_ops=non_compute_ops,
            allowed_single_node_partition_ops=allowed_single_node_partition_ops,
        )
        self.dependency_viewer = _DependencyViewer(graph_module)
        self.node_groups = (
            [set(node_group) for node_group in node_groups] if node_groups else None
        )
        self.node_to_group = collections.defaultdict(int)
        self.all_nodes_in_groups = set()
        if node_groups:
            for i, group in enumerate(self.node_groups):
                for node in group:
                    # Node is in multiple groups - not allowed
                    if node in self.node_to_group:
                        raise ValueError(f"Node {node} exists in multiple groups.")
                    self.node_to_group[node] = i
                    self.all_nodes_in_groups.add(node)

    def _can_merge_partitions(self, p1, p2, partitions_by_id):
        """Check if merging two partitions would create a cycle."""
        p1_nodes = set(partitions_by_id[p1].nodes.keys())
        p2_nodes = set(partitions_by_id[p2].nodes.keys())
        combined_nodes = p1_nodes.union(p2_nodes)

        for node in combined_nodes:
            # Get all downstream nodes that are not in the combined partition
            external_downstreams = {
                n
                for n in self.dependency_viewer.downstreams_of(node)
                if n not in combined_nodes
            }

            # Check if any external downstream nodes have downstream nodes in the combined partition
            for external_node in external_downstreams:
                downstream_nodes = self.dependency_viewer.downstreams_of(external_node)
                if any(n in combined_nodes for n in downstream_nodes):
                    return False

        return True

    def _process_node_groups(
        self,
        new_partition_id,
        partitions_by_id,
        assignment,
        nodes_order,
        partitions_order,
        partition_users,
        partition_map,
    ):
        """Process nodes in predefined groups."""
        group_to_partition_id = {}

        if not self.node_groups:
            return group_to_partition_id

        for i, group in enumerate(self.node_groups):
            # Create a partition for each group
            partition_id = next(new_partition_id)
            partition = Partition(id=partition_id, nodes=set())
            partitions_by_id[partition_id] = partition
            partitions_order[partition_id] = partition_id
            group_to_partition_id[i] = partition_id

            # Add all supported nodes from the group to the partition
            for node in group:
                if self._is_node_supported(node):
                    partition.add_node(node)
                    assignment[node] = partition_id
                    nodes_order[node] = partition_id

            # Set partition users
            partition_users[partition_id] = {
                user
                for node in partition.nodes
                for user in node.users
                if user not in partition.nodes
            }

            # Update partition map
            for node in partition.nodes:
                for user in node.users:
                    target_id = assignment.get(user)
                    if target_id is not None and target_id != partition_id:
                        partition_map[partition_id].add(target_id)
                        partition_map[partition_id].update(partition_map[target_id])

        return group_to_partition_id

    def _process_remaining_nodes(
        self,
        new_partition_id,
        partitions_by_id,
        assignment,
        nodes_order,
        partitions_order,
        partition_users,
        partition_map,
    ):
        """Process nodes not in any predefined group."""
        for node in reversed(self.graph_module.graph.nodes):
            if node in assignment or not self._is_node_supported(node):
                continue

            partition_id = next(new_partition_id)
            nodes_order[node] = partition_id
            partitions_order[partition_id] = partition_id
            partitions_by_id[partition_id] = Partition(id=partition_id, nodes=[node])
            assignment[node] = partition_id
            partition_users[partition_id] = set(node.users)

            # Update partition map
            for user in node.users:
                target_id = assignment.get(user)
                if target_id is not None:
                    partition_map[partition_id].add(target_id)
                    partition_map[partition_id].update(partition_map[target_id])

    def _merge_partitions(
        self,
        partitions_by_id,
        assignment,
        partition_users,
        partition_map,
        partitions_order,
    ):
        """Merge partitions when possible."""
        merged = True
        while merged:
            merged = False
            partition_ids = list(partitions_by_id.keys())

            for i, p1 in enumerate(partition_ids):
                if p1 not in partitions_by_id:
                    continue

                for p2 in partition_ids[i + 1 :]:
                    if p2 not in partitions_by_id:
                        continue

                    # Try to merge partitions if it doesn't create cycles
                    if self._can_merge_partitions(p1, p2, partitions_by_id):
                        self._perform_partition_merge(
                            p1,
                            p2,
                            partitions_by_id,
                            assignment,
                            partition_users,
                            partition_map,
                            partitions_order,
                        )
                        merged = True
                        break

                if merged:
                    break

    def _perform_partition_merge(
        self,
        p1,
        p2,
        partitions_by_id,
        assignment,
        partition_users,
        partition_map,
        partitions_order,
    ):
        """Merge partition p2 into p1."""
        # Merge p2 into p1
        partitions_by_id[p1].nodes.update(partitions_by_id[p2].nodes)
        for node in partitions_by_id[p2].nodes:
            assignment[node] = p1

        # Update partition users
        all_users = partition_users[p1] | partition_users[p2]
        all_users.difference_update(partitions_by_id[p1].nodes)
        partition_users[p1] = all_users

        # Update partition map
        partition_map[p1].update(partition_map[p2])

        # Update partition order
        partitions_order[p1] = min(partitions_order[p1], partitions_order[p2])

        # Remove p2
        del partitions_by_id[p2]
        del partition_users[p2]
        del partitions_order[p2]
        if p2 in partition_map:
            del partition_map[p2]

    def _process_getitem_nodes(self, partitions_by_id, assignment):
        """Post-process getitem nodes."""
        nodes_reassignment = {}

        for node in self.graph_module.graph.nodes:
            # Check if all users are getitem nodes
            is_tuple_output = True
            for user in node.users:
                if (
                    user.op != "call_function"
                    or _get_qualified_name(user.target) != "_operator.getitem"
                ):
                    is_tuple_output = False
                    break

            # Node has tuple outputs, reassign all following getitem nodes into node's partition
            if is_tuple_output:
                id = assignment.get(node, None)
                if id is not None:
                    for user in node.users:
                        if user in assignment and assignment.get(user, None) != id:
                            nodes_reassignment[user] = id

        # Apply reassignments
        for node, id in nodes_reassignment.items():
            if node in assignment:
                partitions_by_id[assignment[node]].remove_node(node)

            assignment[node] = id
            partitions_by_id[id].add_node(node)

    def _filter_single_node_partitions(self, partitions_by_id):
        """Filter out single node partitions if needed."""
        if self.allows_single_node_partition:
            return

        default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
        non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops or []))
        partitions_to_remove = []

        for id, partition in partitions_by_id.items():
            compute_node_count = 0
            for node in partition.nodes:
                if node.op == "call_function":
                    assert callable(node.target)
                    target_name = _get_qualified_name(node.target)

                    if target_name not in non_compute_ops:
                        compute_node_count += 1

                    if (
                        self.allowed_single_node_partition_ops
                        and target_name in self.allowed_single_node_partition_ops
                    ):
                        compute_node_count += 1

            if compute_node_count <= 1:
                partitions_to_remove.append(id)

        for id in partitions_to_remove:
            del partitions_by_id[id]

    def propose_partitions(self) -> list[Partition]:
        """
        Propose partitions for the graph module based on node groups and operator support.

        Returns:
            A list of proposed partitions.
        """
        # Initialize data structures
        partition_map = collections.defaultdict(
            set
        )  # Maps partition IDs to reachable partition IDs
        assignment = {}  # Maps nodes to partition IDs
        partitions_by_id = {}  # Maps partition IDs to partitions
        nodes_order = {}  # Maps nodes to topological order
        partitions_order = {}  # Maps partition IDs to minimum topological order
        partition_users = {}  # Maps partition IDs to partition users
        new_partition_id = itertools.count()

        # Process nodes in predefined groups
        self._process_node_groups(
            new_partition_id,
            partitions_by_id,
            assignment,
            nodes_order,
            partitions_order,
            partition_users,
            partition_map,
        )

        # Process remaining nodes
        self._process_remaining_nodes(
            new_partition_id,
            partitions_by_id,
            assignment,
            nodes_order,
            partitions_order,
            partition_users,
            partition_map,
        )

        # Merge partitions when possible
        self._merge_partitions(
            partitions_by_id,
            assignment,
            partition_users,
            partition_map,
            partitions_order,
        )

        # Post-process getitem nodes
        self._process_getitem_nodes(partitions_by_id, assignment)

        # Filter single node partitions if needed
        self._filter_single_node_partitions(partitions_by_id)

        # Return non-empty partitions
        return [p for p in partitions_by_id.values() if p.size() > 0]
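
For intuition on the cycle check in `_can_merge_partitions`, here is a minimal sketch using the `_DependencyViewer` helper defined above (the example graph is hypothetical):

```python
from torch.fx import symbolic_trace


def g(x):
    a = x + 1      # node "add"
    b = a.relu()   # node "relu": imagine this op is unsupported by the backend
    return b * 2   # node "mul"


gm = symbolic_trace(g)
viewer = _DependencyViewer(gm)  # the helper class added in this diff
nodes = {n.name: n for n in gm.graph.nodes}

# "relu" is downstream of "add" but also upstream of "mul". A merged
# partition {add, mul} would therefore have an external node that both
# consumes its output and feeds its input, i.e. a cycle, which is
# exactly the case _can_merge_partitions rejects.
assert nodes["relu"] in viewer.downstreams_of(nodes["add"])
assert nodes["mul"] in viewer.downstreams_of(nodes["relu"])
```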
