Skip to content

Commit 3067e98

Browse files
authored
Arm backend: Use bucket approach in fuse_equal_placeholder_pass (#13271)
Verified to catch the same duplicates as before in lstm and mv2. Instead of comparing every placeholder against every other one, compute a hash for each placeholder and use it as the key in a dictionary. Equal placeholders produce equal keys, so any dictionary entry that holds multiple values is a group of duplicates. This is a ~O(N) algorithm, compared to the earlier O(N^2) one. The difference shows when measuring the speedup for lstm vs. mv2 — lstm: 120 placeholders (116 dupes), 0.4s -> 0.3s; mv2: 318 placeholders (98 dupes), ~15s -> 0.5s. Signed-off-by: Erik Lundell <[email protected]>
1 parent 72580d2 commit 3067e98

File tree

1 file changed

+43
-44
lines changed

1 file changed

+43
-44
lines changed

backends/arm/_passes/fuse_equal_placeholders_pass.py

Lines changed: 43 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,9 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import hashlib
7+
from collections import defaultdict
8+
69
import torch
710
from executorch.backends.arm._passes.arm_pass_utils import (
811
get_constant_placeholder_kind,
@@ -21,7 +24,7 @@ class FuseEqualPlaceholdersPass(ExportPass):
2124
"""
2225
This pass optimizes memory usage by finding constant placeholders
2326
pointing to identical tensors and fusing them to one single placeholder
24-
with multiple users.
27+
with multiple users, using a cache for faster comparison.
2528
"""
2629

2730
def __init__(self, exported_program: ExportedProgram):
@@ -30,58 +33,54 @@ def __init__(self, exported_program: ExportedProgram):
3033

3134
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
3235
modified = False
33-
const_placeholder_nodes = []
34-
for node in graph_module.graph.nodes:
35-
if is_param_node(self.exported_program, node):
36-
const_placeholder_nodes.append(node)
37-
38-
while const_placeholder_nodes:
3936

40-
# Find equal tensors
41-
node1 = const_placeholder_nodes.pop()
42-
eq_nodes = [node1]
43-
tensor1 = get_param_tensor(self.exported_program, node1)
44-
if tensor1 is None:
37+
# Build a cache of params: mapping hash_key -> list of (node, tensor)
38+
hash_buckets = defaultdict(list)
39+
for node in graph_module.graph.nodes:
40+
if not is_param_node(self.exported_program, node):
4541
continue
42+
tensor = get_param_tensor(self.exported_program, node)
43+
if tensor is None:
44+
continue
45+
# Create a lightweight fingerprint: dtype + shape + SHA1 of raw bytes
46+
# Ensure tensor is on CPU and contiguous
47+
t_cpu = tensor.detach().cpu().contiguous()
48+
data_bytes = t_cpu.numpy().tobytes()
49+
key = (
50+
str(t_cpu.dtype),
51+
tuple(t_cpu.shape),
52+
hashlib.sha1(data_bytes).hexdigest(),
53+
)
54+
hash_buckets[key].append((node, t_cpu))
4655

47-
for node2 in const_placeholder_nodes:
48-
tensor2 = get_param_tensor(self.exported_program, node2)
49-
if tensor2 is None:
50-
continue
51-
52-
if (
53-
tensor1.dtype == tensor2.dtype
54-
and tensor1.shape == tensor2.shape
55-
and torch.allclose(tensor1, tensor2, atol=1e-08)
56-
):
57-
eq_nodes.append(node2)
56+
# For each bucket with more than one entry, fuse:
57+
for nodes_tensors in hash_buckets.values():
58+
if len(nodes_tensors) < 2:
59+
continue
5860

59-
if len(eq_nodes) > 1:
60-
common_name = node1.name + "_common"
61-
common_kind = get_constant_placeholder_kind(
62-
self.exported_program, node1
61+
# Create a new placeholder from first in list of equal placeholders.
62+
rep_node, rep_tensor = nodes_tensors[0]
63+
common_name = rep_node.name + "_common"
64+
common_kind = get_constant_placeholder_kind(self.exported_program, rep_node)
65+
common_persistent = True
66+
with graph_module.graph.inserting_before(rep_node):
67+
common_node = create_constant_placeholder(
68+
self.exported_program,
69+
graph_module.graph,
70+
common_name,
71+
common_kind,
72+
rep_tensor,
73+
common_persistent,
6374
)
64-
common_persisten_buffer = True
65-
66-
with graph_module.graph.inserting_before(node1):
67-
common_node = create_constant_placeholder(
68-
self.exported_program,
69-
graph_module.graph,
70-
common_name,
71-
common_kind,
72-
tensor1,
73-
common_persisten_buffer,
74-
)
75-
76-
for eq_node in eq_nodes:
77-
eq_node.replace_all_uses_with(common_node)
78-
delete_constant_placeholder(self.exported_program, eq_node)
79-
if eq_node != node1:
80-
const_placeholder_nodes.remove(eq_node)
8175

76+
# Replace uses and delete duplicates
77+
for node, _ in nodes_tensors:
78+
node.replace_all_uses_with(common_node)
79+
delete_constant_placeholder(self.exported_program, node)
8280
modified = True
8381

8482
if modified:
8583
graph_module.recompile()
8684
graph_module = super().call(graph_module).graph_module
85+
8786
return PassResult(graph_module=graph_module, modified=modified)

0 commit comments

Comments
 (0)