From cdadc5ee9058e0ce54b40e18500184f8d3f23a1d Mon Sep 17 00:00:00 2001
From: Erik Lundell
Date: Wed, 6 Aug 2025 16:20:29 +0200
Subject: [PATCH] Arm backend: Use bucket approach in fuse_equal_placeholders_pass

Verified to catch the same dupes as before in lstm and mv2.

Instead of comparing all placeholders to each other, compute a hash of
each placeholder and use it as a key in a dictionary. Equal placeholders
-> equal keys. If an entry in the dictionary has multiple values, we
have duplicates. This is a ~O(N) algorithm compared to the earlier
O(N^2) one.

This can be seen by measuring the speedup for lstm vs. mv2:
lstm: 120 placeholders (116 dupes) 0.4s -> 0.3s
mv2:  318 placeholders  (98 dupes) ~15s -> 0.5s

Signed-off-by: Erik Lundell
Change-Id: I7e2e614402488e7f9437cd7db5212c928176ba32
---
 .../_passes/fuse_equal_placeholders_pass.py   | 87 +++++++++----------
 1 file changed, 43 insertions(+), 44 deletions(-)

diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py
index 664a0f8ea6c..5631e2f32e9 100644
--- a/backends/arm/_passes/fuse_equal_placeholders_pass.py
+++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py
@@ -3,6 +3,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import hashlib
+from collections import defaultdict
+
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import (
     get_constant_placeholder_kind,
@@ -21,7 +24,7 @@ class FuseEqualPlaceholdersPass(ExportPass):
     """
     This pass optimizes memory usage by finding constant placeholders
     pointing to identical tensors and fusing them to one single placeholder
-    with multiple users.
+    with multiple users, using a cache for faster comparison.
     """
 
     def __init__(self, exported_program: ExportedProgram):
@@ -30,58 +33,54 @@ def __init__(self, exported_program: ExportedProgram):
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         modified = False
 
-        const_placeholder_nodes = []
-        for node in graph_module.graph.nodes:
-            if is_param_node(self.exported_program, node):
-                const_placeholder_nodes.append(node)
-
-        while const_placeholder_nodes:
-            # Find equal tensors
-            node1 = const_placeholder_nodes.pop()
-            eq_nodes = [node1]
-            tensor1 = get_param_tensor(self.exported_program, node1)
-            if tensor1 is None:
+        # Build a cache of params: mapping hash_key -> list of (node, tensor)
+        hash_buckets = defaultdict(list)
+        for node in graph_module.graph.nodes:
+            if not is_param_node(self.exported_program, node):
                 continue
+            tensor = get_param_tensor(self.exported_program, node)
+            if tensor is None:
+                continue
+            # Create a lightweight fingerprint: dtype + shape + SHA1 of raw bytes
+            # Ensure tensor is on CPU and contiguous
+            t_cpu = tensor.detach().cpu().contiguous()
+            data_bytes = t_cpu.numpy().tobytes()
+            key = (
+                str(t_cpu.dtype),
+                tuple(t_cpu.shape),
+                hashlib.sha1(data_bytes).hexdigest(),
+            )
+            hash_buckets[key].append((node, t_cpu))
 
-            for node2 in const_placeholder_nodes:
-                tensor2 = get_param_tensor(self.exported_program, node2)
-                if tensor2 is None:
-                    continue
-
-                if (
-                    tensor1.dtype == tensor2.dtype
-                    and tensor1.shape == tensor2.shape
-                    and torch.allclose(tensor1, tensor2, atol=1e-08)
-                ):
-                    eq_nodes.append(node2)
+        # For each bucket with more than one entry, fuse:
+        for nodes_tensors in hash_buckets.values():
+            if len(nodes_tensors) < 2:
+                continue
 
-            if len(eq_nodes) > 1:
-                common_name = node1.name + "_common"
-                common_kind = get_constant_placeholder_kind(
-                    self.exported_program, node1
+            # Create a new placeholder from first in list of equal placeholders.
+            rep_node, rep_tensor = nodes_tensors[0]
+            common_name = rep_node.name + "_common"
+            common_kind = get_constant_placeholder_kind(self.exported_program, rep_node)
+            common_persistent = True
+            with graph_module.graph.inserting_before(rep_node):
+                common_node = create_constant_placeholder(
+                    self.exported_program,
+                    graph_module.graph,
+                    common_name,
+                    common_kind,
+                    rep_tensor,
+                    common_persistent,
                 )
-                common_persisten_buffer = True
-
-                with graph_module.graph.inserting_before(node1):
-                    common_node = create_constant_placeholder(
-                        self.exported_program,
-                        graph_module.graph,
-                        common_name,
-                        common_kind,
-                        tensor1,
-                        common_persisten_buffer,
-                    )
-
-                for eq_node in eq_nodes:
-                    eq_node.replace_all_uses_with(common_node)
-                    delete_constant_placeholder(self.exported_program, eq_node)
-                    if eq_node != node1:
-                        const_placeholder_nodes.remove(eq_node)
+            # Replace uses and delete duplicates
+            for node, _ in nodes_tensors:
+                node.replace_all_uses_with(common_node)
+                delete_constant_placeholder(self.exported_program, node)
 
                 modified = True
 
         if modified:
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
 
+        return PassResult(graph_module=graph_module, modified=modified)
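
Note: the fingerprint-and-bucket idea used in the patch can also be shown outside the
FX graph machinery. Below is a minimal standalone sketch under that assumption;
bucket_duplicates and the example parameter names are illustrative only and are not
part of the patch.

import hashlib
from collections import defaultdict

import torch


def bucket_duplicates(named_tensors):
    """Map a content fingerprint to the names of byte-identical tensors."""
    buckets = defaultdict(list)
    for name, tensor in named_tensors:
        # Same normalization as in the pass: CPU, contiguous, raw bytes.
        t_cpu = tensor.detach().cpu().contiguous()
        key = (
            str(t_cpu.dtype),
            tuple(t_cpu.shape),
            hashlib.sha1(t_cpu.numpy().tobytes()).hexdigest(),
        )
        buckets[key].append(name)
    return buckets


if __name__ == "__main__":
    w = torch.ones(4, 4)
    params = [("w1", w), ("w2", w.clone()), ("b", torch.zeros(4))]
    for names in bucket_duplicates(params).values():
        if len(names) > 1:
            # Every name in this bucket points to identical data and could be
            # fused into a single placeholder; prints: duplicates: ['w1', 'w2']
            print("duplicates:", names)

Each tensor is visited once and dictionary lookups are O(1) on average, which is
where the ~O(N) behaviour described in the commit message comes from.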