Skip to content

Commit 29a8612

Browse files
authored
[ET-VK][AOT] Serialize constant tensors via NamedDataMap (#13499)
Summary: When exporting models to Vulkan backend, save constant tensors in the NamedDataMap instead of the constant data section of the delegate header. ## Motivation Prevent screen blackout (Llama 3.2 1B) / device crash (Llama 3.2 3B) when running Llama 3.2 models on Samsung Galaxy S24. This behaviour is related to high peak memory usage when loading the model. For more information, see the top diff/PR in the stack. ## Context This change is based on the equivalent change D70315207/#9153 in XNNPACK. Test Plan: ## Memory Comparison with/without NamedDataMap Measured VmRss using ``` uint64_t getVmRssInKB() { std::ifstream statusFile("/proc/self/status"); std::string l, num; while (std::getline(statusFile, l)) { if (l.substr(0, 5) == "VmRSS") { size_t pos = l.find_first_of("0123456789"); num = l.substr(pos); break; } } uint64_t vmRssInKB = std::stoi(num); return vmRssInKB; } ``` P1908019767 (Meta only) Excerpt: ``` Log 1 | Log 2 --------------------------------------------------|-------------------------------------------------- Memory usage before model compilation: 1115416 KB | Memory usage before model compilation: 1919228 KB Memory usage after graph building: 1924340 KB | Memory usage after graph building: 1924256 KB Memory usage after graph preparation: 1798968 KB | Memory usage after graph preparation: 1782464 KB Memory usage prepack start: 1798968 KB | Memory usage prepack start: 1781968 KB Memory usage after prepack operations: 1271924 KB | Memory usage after prepack operations: 1653496 KB ``` Differential Revision: [D80460034](https://our.internmc.facebook.com/intern/diff/D80460034) [ghstack-poisoned]
1 parent 3c4cabb commit 29a8612

File tree

3 files changed

+51
-5
lines changed

3 files changed

+51
-5
lines changed

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 34 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import ctypes
8+
import hashlib
79
import logging
810
import operator
911
from types import NoneType
@@ -25,6 +27,7 @@
2527
is_symint_node,
2628
TensorRepr,
2729
)
30+
from executorch.exir._serialize._named_data_store import NamedDataStore
2831
from executorch.exir.backend.utils import DelegateMappingBuilder
2932

3033
from executorch.exir.tensor import TensorSpec
@@ -56,6 +59,7 @@ def __init__(
5659
self.input_ids = []
5760
self.output_ids = []
5861
self.const_tensors = []
62+
self.named_data_store = NamedDataStore()
5963

6064
# Mapping from Node to VkValue id
6165
self.node_to_value_ids = {}
@@ -129,8 +133,36 @@ def get_param_tensor(self, node: Node) -> torch.Tensor:
129133
def maybe_add_constant_tensor(self, node: Node) -> int:
130134
constant_id = -1
131135
if is_param_node(self.program, node):
132-
constant_id = len(self.const_tensors)
133-
self.const_tensors.append(self.get_param_tensor(node))
136+
tensor = self.get_param_tensor(node)
137+
138+
# Serialize tensor data to bytes
139+
tensor = tensor.contiguous()
140+
size = tensor.untyped_storage().nbytes()
141+
142+
if size > 0:
143+
array_type = ctypes.c_char * size
144+
array = ctypes.cast(
145+
tensor.untyped_storage().data_ptr(),
146+
ctypes.POINTER(array_type),
147+
).contents
148+
149+
# Generate SHA256 hash as the named key
150+
tensor_bytes = bytes(array)
151+
sha256_hash = hashlib.sha256(tensor_bytes)
152+
named_key = sha256_hash.hexdigest()
153+
154+
# Add to named data store with 16-byte alignment (matching XNNPACK)
155+
self.named_data_store.add_named_data(
156+
named_key, tensor_bytes, alignment=16
157+
)
158+
159+
# Create VkBytes entry with named_key and set offset to indicate named data usage
160+
constant_id = len(self.const_tensors)
161+
self.const_tensors.append((named_key, size))
162+
else:
163+
# Handle empty tensors
164+
constant_id = len(self.const_tensors)
165+
self.const_tensors.append(None)
134166

135167
return constant_id
136168

backends/vulkan/serialization/vulkan_graph_serialize.py

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -191,10 +191,21 @@ def serialize_constant_tensors(
191191

192192
current_offset = len(raw_bytes)
193193
for tensor in const_tensors:
194-
if tensor.numel() == 0:
194+
# The tensor data is stored in the named data map
195+
if isinstance(tensor, tuple):
196+
named_key, size = tensor
197+
vk_graph.constants.append(
198+
VkBytes(
199+
offset=18446744073709551615, # UINT64_MAX to indicate named data
200+
length=size,
201+
named_key=named_key,
202+
)
203+
)
204+
elif tensor is None or (
205+
isinstance(tensor, torch.Tensor) and tensor.numel() == 0
206+
):
195207
vk_graph.constants.append(VkBytes(current_offset, 0))
196-
continue
197-
else:
208+
elif isinstance(tensor, torch.Tensor):
198209
array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
199210
array = ctypes.cast(
200211
tensor.untyped_storage().data_ptr(),
@@ -208,6 +219,8 @@ def serialize_constant_tensors(
208219

209220
vk_graph.constants.append(VkBytes(current_offset, len(tensor_bytes)))
210221
current_offset += aligned_size(len(tensor_bytes))
222+
else:
223+
raise ValueError(f"Unsupported constant tensor type: {type(tensor)}")
211224

212225

213226
def serialize_custom_shaders(

backends/vulkan/vulkan_preprocess.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -229,4 +229,5 @@ def preprocess( # noqa: C901
229229
vk_graph, graph_builder.const_tensors, []
230230
),
231231
debug_handle_map=graph_builder.delegate_mapping_builder.get_delegate_mapping(),
232+
data_store_output=graph_builder.named_data_store.get_named_data_store_output(),
232233
)

0 commit comments

Comments (0)