
Commit 9f70bf9

[XNNPACK][Weights Cache] Use sha256 hash of bytes instead of tensor name
Pull Request resolved: #9333

In production use cases, the Weights Cache manages weights across multiple models, and collisions on tensor names are a real concern. Names like "encoder.layer.weight1" are common in encoder models and may be reused across many different models, even though the tensors they refer to are actually different. A way to alleviate these collision concerns is to provide a strong hashing guarantee on the tensor's bytes: if we use the sha256 hash of the tensor bytes as the named key, we get much stronger guarantees against collisions between weights.

This also gives stronger weight-deduplication guarantees. Today the named key is the only mechanism for deduplicating weights, so tensors whose underlying bytes are identical but whose keys differ cannot be deduplicated. Keying on a hash of the underlying bytes would help with this (though how often that case occurs in practice remains to be seen). Regardless, I think hashing the bytes is much safer in the long term.

The drawback is that this adds a guaranteed 64 bytes per weight, which may be noticeable on smaller models. Open to discussion on whether other hashing algorithms, such as md5, would provide tolerable collision guarantees.

ghstack-source-id: 272502584
@exported-using-ghexport

Differential Revision: [D71212509](https://our.internmc.facebook.com/intern/diff/D71212509/)
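To illustrate the idea, here is a minimal sketch of content-hash keying (not the actual cache implementation; the content_key helper and the example tensor bytes are hypothetical). Keying on the bytes means two same-named but byte-different weights can never collide, while byte-identical weights under different names collapse to one entry:

import hashlib

def content_key(tensor_bytes: bytes) -> str:
    # Key the weights cache on the bytes themselves, not on the tensor's name.
    return hashlib.sha256(tensor_bytes).hexdigest()

# Two models both name a weight "encoder.layer.weight1", but the data differs:
weight_model_a = bytes(range(16))
weight_model_b = bytes(16)  # sixteen zero bytes
assert content_key(weight_model_a) != content_key(weight_model_b)  # distinct keys, no name collision

# Identical bytes under different names deduplicate to the same key:
assert content_key(weight_model_a) == content_key(bytes(range(16)))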
1 parent 5a5fab7 commit 9f70bf9

File tree

1 file changed: +9 −5 lines


backends/xnnpack/operators/node_visitor.py

Lines changed: 9 additions & 5 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import ctypes
+import hashlib
 
 from typing import cast, Dict, List, Optional, Tuple
 
@@ -34,7 +35,6 @@
     check_or_raise,
     get_input_node,
     get_param_tensor,
-    get_tensor_name,
     is_param_node,
     PERM_NCHW_TO_NHWC,
 )
@@ -576,15 +576,19 @@ def get_serialized_buffer_index(
         if quant_params is not None and quant_params.is_qc4w:
             const_val = self.convert_to_qc4w(const_val)
 
-        array_type = ctypes.c_char * const_val.untyped_storage().nbytes()
+        size = const_val.untyped_storage().nbytes()
+        array_type = ctypes.c_char * size
         array = ctypes.cast(
             const_val.untyped_storage().data_ptr(),
             ctypes.POINTER(array_type),
         ).contents
 
-        named_key = get_tensor_name(self.exported_program, get_attr_node)
-        if named_key == "":
-            raise ValueError(f"Tensor from node: {get_attr_node} has no name")
+        check_or_raise(
+            size > 0,
+            f"Serializing constant data node {tensor} but tensor value has no bytes",
+        )
+        sha256_hash = hashlib.sha256(bytes(array))
+        named_key = sha256_hash.hexdigest()
 
         size = const_val.untyped_storage().nbytes()
         xnn_graph.constant_data.append(
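A side note on the overhead mentioned in the commit message: a sha256 digest is 32 bytes, and its hexdigest (the form used as the named key above) is 64 hexadecimal characters, which is where the guaranteed ~64 bytes per weight come from. A quick standalone check:

import hashlib

# The hexdigest used as the named key is always 64 characters long.
assert len(hashlib.sha256(b"any constant tensor bytes").hexdigest()) == 64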
