|
10 | 10 | import torch |
11 | 11 | from executorch.exir.backend.backend_details import ExportedProgram |
12 | 12 | from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( |
13 | | - generate_partitions_from_list_of_nodes, |
| 13 | + generate_grouped_partitions_from_list_of_nodes, |
14 | 14 | ) |
15 | 15 | from executorch.exir.backend.partitioner import ( |
16 | 16 | DelegationSpec, |
17 | 17 | Partitioner, |
18 | 18 | PartitionResult, |
19 | 19 | ) |
| 20 | + |
| 21 | +from exir.backend.canonical_partitioners.pattern_op_partitioner import ( |
| 22 | + generate_grouped_partitions_from_list_of_nodes, |
| 23 | +) |
20 | 24 | from torch.fx.passes.infra.partitioner import Partition |
21 | 25 |
|
22 | 26 |
|
@@ -162,23 +166,49 @@ def filter_fn(node: torch.fx.Node) -> bool: |
162 | 166 | def get_matched_nodes_from_configs( |
163 | 167 | self, ep: ExportedProgram |
164 | 168 | ) -> List[List[torch.fx.Node]]: |
| 169 | + # disjoint set union |
| 170 | + parent = {} |
| 171 | + |
| 172 | + def find(x): |
| 173 | + parent.setdefault(x, x) |
| 174 | + if parent[x] != x: |
| 175 | + parent[x] = find(parent[x]) |
| 176 | + return parent[x] |
| 177 | + |
| 178 | + def union(x, y): |
| 179 | + parent[find(x)] = find(y) |
| 180 | + |
165 | 181 | # gather supported nodes |
166 | | - matched_nodes = [] |
167 | 182 | gm = ep.graph_module |
168 | 183 | for node in gm.graph.nodes: |
169 | | - if node.op == "call_function": |
170 | | - target = format_target_name(node.target.__name__) |
171 | | - if target in self.target_partitioner_configs: |
172 | | - node_config = self.target_partitioner_configs[target] |
173 | | - if node_config.check_constraints(node, ep): |
174 | | - matched_nodes.append(node_config.get_partition(node, ep)) |
| 184 | + if node.op != "call_function": |
| 185 | + continue |
| 186 | + target = format_target_name(node.target.__name__) |
| 187 | + |
| 188 | + if target not in self.target_partitioner_configs: |
| 189 | + continue |
| 190 | + |
| 191 | + node_config = self.target_partitioner_configs[target] |
| 192 | + if not node_config.check_constraints(node, ep): |
| 193 | + continue |
| 194 | + |
| 195 | + partition = node_config.get_partition(node, ep) |
| 196 | + if len(partition) > 0: |
| 197 | + parent[partition[0]] = partition[0] |
| 198 | + for i in range(1, len(partition)): |
| 199 | + union(partition[0], partition[i]) |
| 200 | + |
| 201 | + groups = {} |
| 202 | + for node in parent.keys(): |
| 203 | + root = find(node) |
| 204 | + groups.setdefault(root, set()).add(node) |
175 | 205 |
|
176 | | - return matched_nodes |
| 206 | + return [list(group) for group in groups.values()] |
177 | 207 |
|
178 | 208 | def generate_partitions(self, ep: ExportedProgram) -> List[Partition]: |
179 | 209 | matched_nodes = self.get_matched_nodes_from_configs(ep) |
180 | 210 | # create partitions |
181 | | - partitions = generate_partitions_from_list_of_nodes( |
| 211 | + partitions = generate_grouped_partitions_from_list_of_nodes( |
182 | 212 | ep.graph_module, |
183 | 213 | matched_nodes, |
184 | 214 | ) |
|
0 commit comments