|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
| 7 | +import ctypes |
| 8 | +import hashlib |
7 | 9 | import logging |
8 | 10 | import operator |
9 | 11 | from types import NoneType |
|
25 | 27 | is_symint_node, |
26 | 28 | TensorRepr, |
27 | 29 | ) |
| 30 | +from executorch.exir._serialize._named_data_store import NamedDataStore |
28 | 31 | from executorch.exir.backend.utils import DelegateMappingBuilder |
29 | 32 |
|
30 | 33 | from executorch.exir.tensor import TensorSpec |
@@ -56,6 +59,7 @@ def __init__( |
56 | 59 | self.input_ids = [] |
57 | 60 | self.output_ids = [] |
58 | 61 | self.const_tensors = [] |
| 62 | + self.named_data_store = NamedDataStore() |
59 | 63 |
|
60 | 64 | # Mapping from Node to VkValue id |
61 | 65 | self.node_to_value_ids = {} |
@@ -129,8 +133,36 @@ def get_param_tensor(self, node: Node) -> torch.Tensor: |
129 | 133 | def maybe_add_constant_tensor(self, node: Node) -> int: |
130 | 134 | constant_id = -1 |
131 | 135 | if is_param_node(self.program, node): |
132 | | - constant_id = len(self.const_tensors) |
133 | | - self.const_tensors.append(self.get_param_tensor(node)) |
| 136 | + tensor = self.get_param_tensor(node) |
| 137 | + |
| 138 | + # Serialize tensor data to bytes |
| 139 | + tensor = tensor.contiguous() |
| 140 | + size = tensor.untyped_storage().nbytes() |
| 141 | + |
| 142 | + if size > 0: |
| 143 | + array_type = ctypes.c_char * size |
| 144 | + array = ctypes.cast( |
| 145 | + tensor.untyped_storage().data_ptr(), |
| 146 | + ctypes.POINTER(array_type), |
| 147 | + ).contents |
| 148 | + |
| 149 | + # Generate SHA256 hash as the named key |
| 150 | + tensor_bytes = bytes(array) |
| 151 | + sha256_hash = hashlib.sha256(tensor_bytes) |
| 152 | + named_key = sha256_hash.hexdigest() |
| 153 | + |
| 154 | + # Add to named data store with 16-byte alignment (matching XNNPACK) |
| 155 | + self.named_data_store.add_named_data( |
| 156 | + named_key, tensor_bytes, alignment=16 |
| 157 | + ) |
| 158 | + |
| 159 | + # Create VkBytes entry with named_key and set offset to indicate named data usage |
| 160 | + constant_id = len(self.const_tensors) |
| 161 | + self.const_tensors.append((named_key, size)) |
| 162 | + else: |
| 163 | + # Handle empty tensors |
| 164 | + constant_id = len(self.const_tensors) |
| 165 | + self.const_tensors.append(None) |
134 | 166 |
|
135 | 167 | return constant_id |
136 | 168 |
|
|
0 commit comments