 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import hashlib
+from collections import defaultdict
+
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import (
     get_constant_placeholder_kind,
@@ -21,7 +24,7 @@ class FuseEqualPlaceholdersPass(ExportPass):
     """
     This pass optimizes memory usage by finding constant placeholders
     pointing to identical tensors and fusing them to one single placeholder
-    with multiple users.
+    with multiple users, using a cache for faster comparison.
     """
 
     def __init__(self, exported_program: ExportedProgram):
@@ -30,58 +33,54 @@ def __init__(self, exported_program: ExportedProgram):
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         modified = False
-        const_placeholder_nodes = []
-        for node in graph_module.graph.nodes:
-            if is_param_node(self.exported_program, node):
-                const_placeholder_nodes.append(node)
-
-        while const_placeholder_nodes:
 
-            # Find equal tensors
-            node1 = const_placeholder_nodes.pop()
-            eq_nodes = [node1]
-            tensor1 = get_param_tensor(self.exported_program, node1)
-            if tensor1 is None:
+        # Build a cache of params: mapping hash_key -> list of (node, tensor)
+        hash_buckets = defaultdict(list)
+        for node in graph_module.graph.nodes:
+            if not is_param_node(self.exported_program, node):
                 continue
+            tensor = get_param_tensor(self.exported_program, node)
+            if tensor is None:
+                continue
+            # Create a lightweight fingerprint: dtype + shape + SHA1 of raw bytes
+            # Ensure tensor is on CPU and contiguous
+            t_cpu = tensor.detach().cpu().contiguous()
+            data_bytes = t_cpu.numpy().tobytes()
+            key = (
+                str(t_cpu.dtype),
+                tuple(t_cpu.shape),
+                hashlib.sha1(data_bytes).hexdigest(),
+            )
+            hash_buckets[key].append((node, t_cpu))
 
-            for node2 in const_placeholder_nodes:
-                tensor2 = get_param_tensor(self.exported_program, node2)
-                if tensor2 is None:
-                    continue
-
-                if (
-                    tensor1.dtype == tensor2.dtype
-                    and tensor1.shape == tensor2.shape
-                    and torch.allclose(tensor1, tensor2, atol=1e-08)
-                ):
-                    eq_nodes.append(node2)
+        # For each bucket with more than one entry, fuse:
+        for nodes_tensors in hash_buckets.values():
+            if len(nodes_tensors) < 2:
+                continue
 
-            if len(eq_nodes) > 1:
-                common_name = node1.name + "_common"
-                common_kind = get_constant_placeholder_kind(
-                    self.exported_program, node1
+            # Create a new placeholder from first in list of equal placeholders.
+            rep_node, rep_tensor = nodes_tensors[0]
+            common_name = rep_node.name + "_common"
+            common_kind = get_constant_placeholder_kind(self.exported_program, rep_node)
+            common_persistent = True
+            with graph_module.graph.inserting_before(rep_node):
+                common_node = create_constant_placeholder(
+                    self.exported_program,
+                    graph_module.graph,
+                    common_name,
+                    common_kind,
+                    rep_tensor,
+                    common_persistent,
                 )
-                common_persisten_buffer = True
-
-                with graph_module.graph.inserting_before(node1):
-                    common_node = create_constant_placeholder(
-                        self.exported_program,
-                        graph_module.graph,
-                        common_name,
-                        common_kind,
-                        tensor1,
-                        common_persisten_buffer,
-                    )
-
-                for eq_node in eq_nodes:
-                    eq_node.replace_all_uses_with(common_node)
-                    delete_constant_placeholder(self.exported_program, eq_node)
-                    if eq_node != node1:
-                        const_placeholder_nodes.remove(eq_node)
 
+            # Replace uses and delete duplicates
+            for node, _ in nodes_tensors:
+                node.replace_all_uses_with(common_node)
+                delete_constant_placeholder(self.exported_program, node)
                 modified = True
 
         if modified:
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
+
         return PassResult(graph_module=graph_module, modified=modified)
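For context, here is a minimal, self-contained sketch of the fingerprinting scheme the new pass body relies on: tensors are grouped by `(dtype, shape, SHA-1 of raw bytes)`, so only bitwise-identical tensors land in the same bucket. The helper `bucket_tensors` and the toy tensors below are illustrative, not part of this patch; note that, unlike the removed pairwise `torch.allclose(..., atol=1e-08)` check, the hash key treats nearly-equal values as distinct.

```python
import hashlib
from collections import defaultdict

import torch


def bucket_tensors(named_tensors):
    """Group tensors by (dtype, shape, SHA-1 of raw bytes).

    Only bitwise-identical tensors share a bucket, mirroring the
    hash_buckets cache built in FuseEqualPlaceholdersPass.call.
    """
    buckets = defaultdict(list)
    for name, tensor in named_tensors:
        # Same normalization as the pass: detach, move to CPU, make contiguous.
        t_cpu = tensor.detach().cpu().contiguous()
        key = (
            str(t_cpu.dtype),
            tuple(t_cpu.shape),
            hashlib.sha1(t_cpu.numpy().tobytes()).hexdigest(),
        )
        buckets[key].append((name, t_cpu))
    return buckets


a = torch.ones(4)
b = torch.ones(4)         # bitwise equal to `a` -> same bucket
c = torch.ones(4) + 1e-6  # allclose to `a`, but different bytes -> own bucket

for group in bucket_tensors([("a", a), ("b", b), ("c", c)]).values():
    print(sorted(name for name, _ in group))
# Prints: ['a', 'b'] then ['c']
```

One caveat worth keeping in mind: the raw-bytes key goes through `.numpy()`, which works for the usual float/int parameter dtypes but raises for dtypes without a NumPy counterpart such as `torch.bfloat16`; those would need a different raw-bytes path (e.g. viewing the storage as bytes) before hashing.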