diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py index 700c7d1b753..358b3085c80 100644 --- a/backends/xnnpack/partition/xnnpack_partitioner.py +++ b/backends/xnnpack/partition/xnnpack_partitioner.py @@ -85,20 +85,30 @@ def generate_per_op_partitions(self, ep: ExportedProgram) -> List[Partition]: partitions = [] matched_nodes = self.get_matched_nodes_from_configs(ep) partition_id = itertools.count() - nodes_seen = set() + nodes_seen = {} for match in matched_nodes: - match_set = set(match) + # for debug information we map the node to the string form + # of the partition it belongs to + match_map = dict.fromkeys(match, str(match)) # We only create partitions from the first PartitionerConfig match # if a subsequent partitioner match contains the same node, we do # not create a partition for it - if match_set.isdisjoint(nodes_seen): + overlap = match_map.keys() & nodes_seen.keys() + if len(overlap) == 0: partitions.append( Partition( id=next(partition_id), - nodes=match_set, + nodes=match_map.keys(), ) ) - nodes_seen.update(match_set) + nodes_seen.update(match_map) + else: + error_str = f"per_op mode expects no overlaps between partitions but the partition {match_map.keys()} overlaps with the following partitions:\n" + for overlap_node in overlap: + error_str += f"{nodes_seen[overlap_node]}\n" + + raise RuntimeError(error_str) + return partitions