From 7014fe019dc8eb819439078102cc964e27c072b5 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 5 Dec 2024 13:12:54 -0800 Subject: [PATCH] error in per-op mode Summary: let's explicitly error when per-op mode produces op partitions that overlap, silently not partitioning is a bit more annoying to debug Differential Revision: D66836605 --- .../xnnpack/partition/xnnpack_partitioner.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py index 700c7d1b753..358b3085c80 100644 --- a/backends/xnnpack/partition/xnnpack_partitioner.py +++ b/backends/xnnpack/partition/xnnpack_partitioner.py @@ -85,20 +85,30 @@ def generate_per_op_partitions(self, ep: ExportedProgram) -> List[Partition]: partitions = [] matched_nodes = self.get_matched_nodes_from_configs(ep) partition_id = itertools.count() - nodes_seen = set() + nodes_seen = {} for match in matched_nodes: - match_set = set(match) + # for debug information we map the node to the string form + # of the partition it belongs to + match_map = dict.fromkeys(match, str(match)) # We only create partitions from the first PartitionerConfig match # if a subsequent partitioner match contains the same node, we do # not create a partition for it - if match_set.isdisjoint(nodes_seen): + overlap = match_map.keys() & nodes_seen.keys() + if len(overlap) == 0: partitions.append( Partition( id=next(partition_id), - nodes=match_set, + nodes=match_map.keys(), ) ) - nodes_seen.update(match_set) + nodes_seen.update(match_map) + else: + error_str = f"per_op mode expects no overlaps between partitions but the partition {match_map.keys()} overlaps with the following partitions:\n" + for overlap_node in overlap: + error_str += f"{nodes_seen[overlap_node]}\n" + + raise RuntimeError(error_str) + return partitions