Commit d3f99fa

[Group Partitioner] leverage group partitioner for config-based partitioner (#12845)
We use the new group-based partitioner in the `ConfigerationBasedPartitioner`. This resolves issues in the XnnpackPartitioner that arise when required dependencies end up in different partitions. For example, consider the following case:

<img width="1174" height="1141" alt="separate_example" src="https://github.com/user-attachments/assets/1d1a93e8-8405-41c9-b441-29b53b6c0c12" />

Here, two linear layers share the same activation and therefore the same dynamically quantized linear chain. The capability-based partitioner partitions greedily from the bottom up, so we could end up with something like this:

<img width="1436" height="1223" alt="bad_partition_example" src="https://github.com/user-attachments/assets/584bc8f5-9b8e-47fb-a269-1ebed428779e" />

This is bad because, when processing the second partition, we lose the semantics of the dynamically quantized tensor. The dynamic quant chain needs to be grouped with the linears, which is why the XNNPACK partitioner needs the group-based partitioner. It lets us enforce that dependencies stay in the same partition, giving us a correct result like this:

<img width="1256" height="1414" alt="good_partition_example" src="https://github.com/user-attachments/assets/21b79570-f4d4-4b35-8876-1b981f77d478" />

This resolves the issues we have seen with the MobileBERT model and allows us to efficiently partition and lower it.
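The constraint described above can be stated mechanically: a partitioning is valid only if every node's required dependencies (here, the dynamic-quant chain feeding each linear) land in the same partition as the node itself. A minimal sketch of such a validity check, using hypothetical node names rather than real `torch.fx` nodes:

```python
# Toy validity check: every node's required dependencies must sit in the
# same partition as the node itself. Node names here are hypothetical.
REQUIRED_DEPS = {
    "linear_1": {"quantize", "choose_qparams", "dequantize"},
    "linear_2": {"quantize", "choose_qparams", "dequantize"},
}

def partition_is_valid(partitions):
    # Map each node to the index of the partition containing it.
    node_to_part = {n: i for i, part in enumerate(partitions) for n in part}
    return all(
        node_to_part[dep] == node_to_part[node]
        for node, deps in REQUIRED_DEPS.items()
        for dep in deps
    )

# Greedy bottom-up split: the dq chain is separated from linear_2 -> invalid.
bad = [{"quantize", "choose_qparams", "dequantize", "linear_1"}, {"linear_2"}]
# Grouped partitioning keeps the shared chain with both consumers -> valid.
good = [{"quantize", "choose_qparams", "dequantize", "linear_1", "linear_2"}]
```

The group-based partitioner enforces this property by construction, rather than checking it after the fact.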
Dynamically Quantized Mobilebert (tensor values elided):

```
./cmake-out/executor_runner --model_path=./mobilebert_xnnpack_q8.pte
I 00:00:00.001371 executorch:cpuinfo_utils.cpp:71] Reading file /sys/devices/soc0/image_version
I 00:00:00.001409 executorch:cpuinfo_utils.cpp:87] Failed to open midr file /sys/devices/soc0/image_version
I 00:00:00.001413 executorch:cpuinfo_utils.cpp:167] Number of efficient cores 4
I 00:00:00.001414 executorch:executor_runner.cpp:143] Resetting threadpool with num threads = 6
I 00:00:00.009699 executorch:executor_runner.cpp:166] Model file ./mobilebert_xnnpack_q8.pte is loaded.
I 00:00:00.009710 executorch:executor_runner.cpp:175] Using method forward
I 00:00:00.009730 executorch:executor_runner.cpp:226] Setting up planned buffer 0, size 77952.
I 00:00:00.036463 executorch:executor_runner.cpp:251] Method loaded.
I 00:00:00.071498 executorch:executor_runner.cpp:291] Model executed successfully 1 time(s) in 35.018417 ms.
I 00:00:00.071513 executorch:executor_runner.cpp:295] 2 outputs:
Output 0: tensor(sizes=[1, 8, 512], [ -8.04689e+06, 89921.5, -40037.4, 8.52506e+06, 1.17963e+07, 20380.7, ..., 6.39607, 0.375988, -1.43199, 2.66983, ])
Output 1: tensor(sizes=[1, 512], [ -8.04689e+06, 89921.5, -40037.4, 8.52506e+06, 1.17963e+07, 20380.7, ..., -682197., 417969, -21151.1, ])
```

Differential Revision: [D79020721](https://our.internmc.facebook.com/intern/diff/D79020721)
1 parent de89397 · commit d3f99fa

File tree

3 files changed: +117 −10 lines changed

exir/backend/canonical_partitioners/TARGETS

1 addition & 0 deletions

```diff
@@ -18,6 +18,7 @@ runtime.python_library(
     deps = [
         "//caffe2:torch",
         "//executorch/exir/backend:partitioner",
+        ":group_partitioner_lib",
     ],
 )
```

exir/backend/canonical_partitioners/config_partitioner.py

68 additions & 10 deletions

```diff
@@ -10,16 +10,24 @@
 import torch
 from executorch.exir.backend.backend_details import ExportedProgram
 from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
-    generate_partitions_from_list_of_nodes,
+    generate_grouped_partitions_from_list_of_nodes,
 )
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
     Partitioner,
     PartitionResult,
 )
+
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
 from torch.fx.passes.infra.partitioner import Partition


+def is_constant_data(ep: ExportedProgram, node: torch.fx.Node) -> bool:
+    return (
+        is_param(ep, node) or is_buffer(ep, node) or is_lifted_tensor_constant(ep, node)
+    )
+
+
 def format_target_name(target_name: str) -> str:
     """
     We remove the dialect name space from the target name. We generally
@@ -100,6 +108,35 @@ def get_partition(
         pass


+class DSJ:
+    """
+    Disjoint set union data structure used to find connected components in the graph.
+    """
+
+    def __init__(self):
+        self.parent = {}
+
+    def find(self, x):
+        self.parent.setdefault(x, x)
+        if self.parent[x] != x:
+            self.parent[x] = self.find(self.parent[x])
+        return self.parent[x]
+
+    def union(self, x, y):
+        self.parent[self.find(x)] = self.find(y)
+
+    def contains(self, x):
+        return x in self.parent
+
+    def gen_groups(self):
+        groups = {}
+        for node in self.parent.keys():
+            root = self.find(node)
+            groups.setdefault(root, set()).add(node)
+
+        return [list(group) for group in groups.values()]
+
+
 class ConfigerationBasedPartitioner(Partitioner):
     def __init__(
         self,
@@ -162,23 +199,44 @@ def filter_fn(node: torch.fx.Node) -> bool:
     def get_matched_nodes_from_configs(
         self, ep: ExportedProgram
     ) -> List[List[torch.fx.Node]]:
+        # disjoint set union for merging partitions
+        dsj = DSJ()
+
         # gather supported nodes
-        matched_nodes = []
         gm = ep.graph_module
         for node in gm.graph.nodes:
-            if node.op == "call_function":
-                target = format_target_name(node.target.__name__)
-                if target in self.target_partitioner_configs:
-                    node_config = self.target_partitioner_configs[target]
-                    if node_config.check_constraints(node, ep):
-                        matched_nodes.append(node_config.get_partition(node, ep))
+            if node.op != "call_function":
+                continue
+            target = format_target_name(node.target.__name__)
+
+            if target not in self.target_partitioner_configs:
+                continue
+
+            node_config = self.target_partitioner_configs[target]
+            if not node_config.check_constraints(node, ep):
+                continue
+
+            partition_candidate = node_config.get_partition(node, ep)
+            partition = []
+            for node in partition_candidate:
+                # partitioner infra copies constant data across partitions, so it
+                # is ok if this partition doesn't have it
+                if is_constant_data(ep, node) and dsj.contains(node):
+                    continue
+                partition.append(node)
+
+            # Union overlaps into a single group
+            if len(partition) > 0:
+                dsj.find(partition[0])
+                for i in range(1, len(partition)):
+                    dsj.union(partition[0], partition[i])

-        return matched_nodes
+        return dsj.gen_groups()

     def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
         matched_nodes = self.get_matched_nodes_from_configs(ep)
         # create partitions
-        partitions = generate_partitions_from_list_of_nodes(
+        partitions = generate_grouped_partitions_from_list_of_nodes(
             ep.graph_module,
             matched_nodes,
         )
```
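The `DSJ` helper added to `config_partitioner.py` is a standard union-find with path compression: because overlapping partition candidates share at least one node, unioning each candidate's nodes collapses the overlapping candidates into a single group. A standalone sketch of the same behavior (the class is re-declared here so it runs outside ExecuTorch; node names are hypothetical):

```python
class DSJ:
    """Disjoint set union, mirroring the helper added in config_partitioner.py."""

    def __init__(self):
        self.parent = {}

    def find(self, x):
        # Path compression: point every visited node directly at the root.
        self.parent.setdefault(x, x)
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        self.parent[self.find(x)] = self.find(y)

    def gen_groups(self):
        # Bucket every tracked node under its root representative.
        groups = {}
        for node in self.parent:
            groups.setdefault(self.find(node), set()).add(node)
        return list(groups.values())

# Two linear partitions share the same dequantize-chain node "dq":
dsj = DSJ()
for partition in [["dq", "linear_1"], ["dq", "linear_2"]]:
    for node in partition[1:]:
        dsj.union(partition[0], node)

groups = dsj.gen_groups()  # the overlap merges both candidates into one group
```

This mirrors how `get_matched_nodes_from_configs` now returns merged groups instead of raw per-config partitions.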

exir/backend/canonical_partitioners/pattern_op_partitioner.py

48 additions & 0 deletions

```diff
@@ -8,6 +8,10 @@
 from typing import List, Optional

 import torch
+
+from executorch.exir.backend.canonical_partitioners.group_partitioner import (
+    GroupBasedPartitioner,
+)
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
 from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
@@ -56,6 +60,50 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
     return partition_list


+def generate_grouped_partitions_from_list_of_nodes(
+    graph_module: torch.fx.GraphModule,
+    pattern_list: Optional[List[List[torch.fx.Node]]] = None,
+    op_support: Optional[OperatorSupportBase] = None,
+) -> List[Partition]:
+    final_op_support: Optional[OperatorSupportBase] = op_support
+
+    if pattern_list is not None:
+        # Tag all the nodes in these patterns
+        for node_list in pattern_list:
+            for node in node_list:
+                node.meta["match"] = True
+
+        class MatchTag(OperatorSupportBase):
+            def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+                return node.meta.get("match", False)
+
+        final_op_support = (
+            MatchTag()
+            if final_op_support is None
+            else any_chain(final_op_support, MatchTag())
+        )
+
+    assert (
+        final_op_support is not None
+    ), "Did not give a pattern or OperatorSupportBase instance to partition with"
+
+    # Run the GroupBasedPartitioner to return the largest possible
+    # subgraphs containing the nodes with the tags
+    group_partitioner = GroupBasedPartitioner(
+        graph_module,
+        final_op_support,
+        node_groups=pattern_list,
+        allows_single_node_partition=True,
+    )
+    partition_list = group_partitioner.propose_partitions()
+
+    # Remove the metadata field we added
+    for partition in partition_list:
+        for node in partition.nodes:
+            node.meta.pop("match", False)
+    return partition_list


 def generate_pattern_op_partitions(
     graph_module: torch.fx.GraphModule,
     patterns: Optional[List[torch.fx.Graph]] = None,
```