|
10 | 10 | import torch
|
11 | 11 | from executorch.exir.backend.backend_details import ExportedProgram
|
12 | 12 | from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
|
13 |
| - generate_partitions_from_list_of_nodes, |
| 13 | + generate_grouped_partitions_from_list_of_nodes, |
14 | 14 | )
|
15 | 15 | from executorch.exir.backend.partitioner import (
|
16 | 16 | DelegationSpec,
|
17 | 17 | Partitioner,
|
18 | 18 | PartitionResult,
|
19 | 19 | )
|
| 20 | + |
| 21 | +from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param |
20 | 22 | from torch.fx.passes.infra.partitioner import Partition
|
21 | 23 |
|
22 | 24 |
|
| 25 | +def is_constant_data(ep: ExportedProgram, node: torch.fx.Node) -> bool: |
| 26 | + return ( |
| 27 | + is_param(ep, node) or is_buffer(ep, node) or is_lifted_tensor_constant(ep, node) |
| 28 | + ) |
| 29 | + |
| 30 | + |
23 | 31 | def format_target_name(target_name: str) -> str:
|
24 | 32 | """
|
25 | 33 | We remove the dialect name space from the target name. We generally
|
@@ -100,6 +108,35 @@ def get_partition(
|
100 | 108 | pass
|
101 | 109 |
|
102 | 110 |
|
| 111 | +class DSJ: |
| 112 | + """ |
| 113 | + Disjoint set union data structure used to find connected components in the graph. |
| 114 | + """ |
| 115 | + |
| 116 | + def __init__(self): |
| 117 | + self.parent = {} |
| 118 | + |
| 119 | + def find(self, x): |
| 120 | + self.parent.setdefault(x, x) |
| 121 | + if self.parent[x] != x: |
| 122 | + self.parent[x] = self.find(self.parent[x]) |
| 123 | + return self.parent[x] |
| 124 | + |
| 125 | + def union(self, x, y): |
| 126 | + self.parent[self.find(x)] = self.find(y) |
| 127 | + |
| 128 | + def contains(self, x): |
| 129 | + return x in self.parent |
| 130 | + |
| 131 | + def gen_groups(self): |
| 132 | + groups = {} |
| 133 | + for node in self.parent.keys(): |
| 134 | + root = self.find(node) |
| 135 | + groups.setdefault(root, set()).add(node) |
| 136 | + |
| 137 | + return [list(group) for group in groups.values()] |
| 138 | + |
| 139 | + |
103 | 140 | class ConfigerationBasedPartitioner(Partitioner):
|
104 | 141 | def __init__(
|
105 | 142 | self,
|
@@ -162,23 +199,44 @@ def filter_fn(node: torch.fx.Node) -> bool:
|
162 | 199 | def get_matched_nodes_from_configs(
|
163 | 200 | self, ep: ExportedProgram
|
164 | 201 | ) -> List[List[torch.fx.Node]]:
|
| 202 | + # disjoint set union for merging partitions |
| 203 | + dsj = DSJ() |
| 204 | + |
165 | 205 | # gather supported nodes
|
166 |
| - matched_nodes = [] |
167 | 206 | gm = ep.graph_module
|
168 | 207 | for node in gm.graph.nodes:
|
169 |
| - if node.op == "call_function": |
170 |
| - target = format_target_name(node.target.__name__) |
171 |
| - if target in self.target_partitioner_configs: |
172 |
| - node_config = self.target_partitioner_configs[target] |
173 |
| - if node_config.check_constraints(node, ep): |
174 |
| - matched_nodes.append(node_config.get_partition(node, ep)) |
| 208 | + if node.op != "call_function": |
| 209 | + continue |
| 210 | + target = format_target_name(node.target.__name__) |
| 211 | + |
| 212 | + if target not in self.target_partitioner_configs: |
| 213 | + continue |
| 214 | + |
| 215 | + node_config = self.target_partitioner_configs[target] |
| 216 | + if not node_config.check_constraints(node, ep): |
| 217 | + continue |
| 218 | + |
| 219 | + partition_candidate = node_config.get_partition(node, ep) |
| 220 | + partition = [] |
| 221 | + for node in partition_candidate: |
| 222 | + # partitioner infra copies constant data across partitions, so it |
| 223 | + # is ok if this partition doesn't have it |
| 224 | + if is_constant_data(ep, node) and dsj.contains(node): |
| 225 | + continue |
| 226 | + partition.append(node) |
| 227 | + |
| 228 | + # Union overlaps into a single group |
| 229 | + if len(partition) > 0: |
| 230 | + dsj.find(partition[0]) |
| 231 | + for i in range(1, len(partition)): |
| 232 | + dsj.union(partition[0], partition[i]) |
175 | 233 |
|
176 |
| - return matched_nodes |
| 234 | + return dsj.gen_groups() |
177 | 235 |
|
178 | 236 | def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
|
179 | 237 | matched_nodes = self.get_matched_nodes_from_configs(ep)
|
180 | 238 | # create partitions
|
181 |
| - partitions = generate_partitions_from_list_of_nodes( |
| 239 | + partitions = generate_grouped_partitions_from_list_of_nodes( |
182 | 240 | ep.graph_module,
|
183 | 241 | matched_nodes,
|
184 | 242 | )
|
|
0 commit comments