Commit 0608102

[XNNPACK][Weights Cache] Use sha256 hash of bytes instead of tensor name
In production use cases, I've become increasingly concerned about the Weights Cache managing weights across multiple models and the potential for name collisions. Names like "encoder.layer.weight1" are common in encoder models and may be reused across many different models. In practice, a tensor with that name will hold different data in each model.

One way to alleviate these concerns is to key each weight on a strong hash of the tensor's bytes. If we use the sha256 hash of the tensor bytes as the named key, collisions between different weights become far less likely.

This can also improve weight deduplication. For now the named key is the only mechanism for deduplicating weights, so if two tensors have the same underlying bytes but different keys, we can't deduplicate them. Keying on a hash of the underlying bytes would handle that case (though how often it happens remains to be seen). Regardless, I think hashing the bytes will be much safer in the long term.

The drawback is that this adds a guaranteed 64 bytes per weight (the length of the sha256 hex digest). On smaller models this might amount to a bit. I'm open to discussing whether other hashing algorithms, such as md5, might provide tolerable collision guarantees.

Differential Revision: [D71212509](https://our.internmc.facebook.com/intern/diff/D71212509/)
ghstack-source-id: 272289115
Pull Request resolved: #9333
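
To make the collision and deduplication argument concrete, here is a minimal sketch that is not taken from the ExecuTorch code. The two model dictionaries, the `content_key` helper, and the plain-dict caches are hypothetical stand-ins for the real Weights Cache; only `hashlib.sha256` is a real API.

```python
import hashlib


def content_key(data: bytes) -> str:
    """Return the SHA-256 hex digest of the raw bytes (64 hex characters)."""
    return hashlib.sha256(data).hexdigest()


# Hypothetical weights: both models reuse one name for different data, and
# model B also stores model A's data under a different name.
model_a = {"encoder.layer.weight1": bytes([1, 2, 3, 4])}
model_b = {
    "encoder.layer.weight1": bytes([9, 9, 9, 9]),
    "decoder.layer.weight7": bytes([1, 2, 3, 4]),
}

name_cache = {}  # keyed by tensor name (the old behaviour)
hash_cache = {}  # keyed by a hash of the bytes (the new behaviour)
for model in (model_a, model_b):
    for name, data in model.items():
        name_cache[name] = data  # model B overwrites model A's entry
        hash_cache[content_key(data)] = data  # identical bytes share one entry

# With name keys, model A's weight has been replaced by model B's.
print(name_cache["encoder.layer.weight1"] == model_a["encoder.layer.weight1"])
# With hash keys, both distinct tensors survive and the duplicate is merged.
print(len(hash_cache))
```

The same helper with `hashlib.md5` would produce a 32-character hex key instead of 64, which is the size trade-off raised above.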
1 parent 2522789 commit 0608102

1 file changed: +3 -3 lines changed
backends/xnnpack/operators/node_visitor.py

Lines changed: 3 additions & 3 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import ctypes
+import hashlib
 
 from typing import cast, Dict, List, Optional, Tuple
 
@@ -582,9 +583,8 @@ def get_serialized_buffer_index(
     ctypes.POINTER(array_type),
 ).contents
 
-named_key = get_tensor_name(self.exported_program, get_attr_node)
-if named_key == "":
-    raise ValueError(f"Tensor from node: {get_attr_node} has no name")
+sha256_hash = hashlib.sha256(bytes(array))
+named_key = sha256_hash.hexdigest()
 
 size = const_val.untyped_storage().nbytes()
 xnn_graph.constant_data.append(
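
The changed lines hash a ctypes view of the tensor's storage. The following standalone sketch reproduces that pattern without PyTorch. The `bytearray` named `storage` is an assumed stand-in for `const_val.untyped_storage()`, and `backing` exists only to obtain a raw address comparable to `data_ptr()`.

```python
import ctypes
import hashlib

# Stand-in for a tensor's storage (assumption: the real code reads the bytes
# from const_val.untyped_storage() instead of a bytearray).
storage = bytearray(b"\x00\x01\x02\x03\x04\x05\x06\x07")
nbytes = len(storage)

# View the raw memory as a fixed-size char array, as the diff context does.
array_type = ctypes.c_char * nbytes
backing = array_type.from_buffer(storage)  # shares memory with storage
array = ctypes.cast(
    ctypes.addressof(backing),  # plays the role of data_ptr()
    ctypes.POINTER(array_type),
).contents

# bytes(array) copies the full buffer, including any zero bytes, and the hex
# digest of that copy becomes the cache key.
sha256_hash = hashlib.sha256(bytes(array))
named_key = sha256_hash.hexdigest()
print(len(named_key))
```

Because the key depends only on the bytes, two tensors with identical contents map to the same `named_key` regardless of their names.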
