36 changes: 34 additions & 2 deletions backends/vulkan/serialization/vulkan_graph_builder.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import ctypes
+import hashlib
 import logging
 import operator
 from types import NoneType
@@ -25,6 +27,7 @@
     is_symint_node,
     TensorRepr,
 )
+from executorch.exir._serialize._named_data_store import NamedDataStore
 from executorch.exir.backend.utils import DelegateMappingBuilder
 
 from executorch.exir.tensor import TensorSpec
@@ -56,6 +59,7 @@ def __init__(
         self.input_ids = []
         self.output_ids = []
         self.const_tensors = []
+        self.named_data_store = NamedDataStore()
 
         # Mapping from Node to VkValue id
         self.node_to_value_ids = {}
@@ -129,8 +133,36 @@ def get_param_tensor(self, node: Node) -> torch.Tensor:
     def maybe_add_constant_tensor(self, node: Node) -> int:
         constant_id = -1
         if is_param_node(self.program, node):
-            constant_id = len(self.const_tensors)
-            self.const_tensors.append(self.get_param_tensor(node))
+            tensor = self.get_param_tensor(node)
+
+            # Serialize tensor data to bytes
+            tensor = tensor.contiguous()
+            size = tensor.untyped_storage().nbytes()
+
+            if size > 0:
+                array_type = ctypes.c_char * size
+                array = ctypes.cast(
+                    tensor.untyped_storage().data_ptr(),
+                    ctypes.POINTER(array_type),
+                ).contents
+
+                # Generate SHA256 hash as the named key
+                tensor_bytes = bytes(array)
+                sha256_hash = hashlib.sha256(tensor_bytes)
+                named_key = sha256_hash.hexdigest()

Contributor: "we deduplicate the data in the pte using the value, and not the key right?"

+
+                # Add to named data store with 16-byte alignment (matching XNNPACK)

Contributor: "+1"

+                self.named_data_store.add_named_data(
+                    named_key, tensor_bytes, alignment=16
+                )
+
+                # Create VkBytes entry with named_key and set offset to indicate named data usage
+                constant_id = len(self.const_tensors)
+                self.const_tensors.append((named_key, size))
+            else:
+                # Handle empty tensors
+                constant_id = len(self.const_tensors)
+                self.const_tensors.append(None)
 
         return constant_id

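Because the named key above is a SHA-256 digest of the tensor's bytes, two constants with identical contents produce the same key, so deduplicating by key and deduplicating by value coincide (up to hash collisions), which is what the review question is getting at. Below is a self-contained sketch of that content-addressed behavior; `ContentAddressedStore` is a hypothetical stand-in, not the ExecuTorch `NamedDataStore`, and the `ctypes` check assumes a contiguous CPU tensor as in the diff.

```python
# Hypothetical illustration; ContentAddressedStore is not the ExecuTorch NamedDataStore.
import ctypes
import hashlib

import torch


class ContentAddressedStore:
    """Stores blobs under a key derived from their contents (SHA-256 hex digest)."""

    def __init__(self) -> None:
        self.blobs = {}

    def add(self, data: bytes) -> str:
        key = hashlib.sha256(data).hexdigest()
        # Identical payloads hash to the same key, so the bytes are kept only once.
        self.blobs.setdefault(key, data)
        return key


store = ContentAddressedStore()
a = torch.ones(4).contiguous()
b = torch.ones(4).contiguous()  # same values as `a`, different tensor object

# The ctypes view over untyped storage yields the same bytes as .numpy().tobytes()
# for a contiguous CPU tensor.
size = a.untyped_storage().nbytes()
view = ctypes.cast(
    a.untyped_storage().data_ptr(), ctypes.POINTER(ctypes.c_char * size)
).contents
assert bytes(view) == a.numpy().tobytes()

key_a = store.add(a.numpy().tobytes())
key_b = store.add(b.numpy().tobytes())
assert key_a == key_b          # same value -> same key
assert len(store.blobs) == 1   # the duplicate constant adds no extra payload
```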
15 changes: 13 additions & 2 deletions backends/vulkan/serialization/vulkan_graph_serialize.py
@@ -191,10 +191,21 @@ def serialize_constant_tensors(
 
     current_offset = len(raw_bytes)
     for tensor in const_tensors:
-        if tensor.numel() == 0:
+        # The tensor data is stored in the named data map
+        if isinstance(tensor, tuple):
+            named_key, size = tensor
+            vk_graph.constants.append(
+                VkBytes(
+                    offset=18446744073709551615,  # UINT64_MAX to indicate named data
+                    length=size,
+                    named_key=named_key,
+                )
+            )
+        elif tensor is None or tensor.numel() == 0:
-            assert isinstance(tensor, torch.Tensor)
             vk_graph.constants.append(VkBytes(current_offset, 0))
             continue
+        else:
+            assert isinstance(tensor, torch.Tensor)
             array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
             array = ctypes.cast(
                 tensor.untyped_storage().data_ptr(),
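On the consumer side, the UINT64_MAX offset acts as a sentinel: an entry with a real offset points into the serialized constant byte blob (`raw_bytes`), while a sentinel entry is looked up by `named_key` in the named data map. A minimal sketch of that dispatch follows; `ConstantEntry` and `resolve_constant` are hypothetical stand-ins (not the real `VkBytes` class or any ExecuTorch runtime API), kept only to the three fields used in the diff.

```python
# Illustrative sketch only; not part of the ExecuTorch runtime.
from dataclasses import dataclass
from typing import Dict, Optional

UINT64_MAX = 18446744073709551615  # sentinel offset meaning "look up by named_key"


@dataclass
class ConstantEntry:  # mirrors the VkBytes fields used in the diff
    offset: int
    length: int
    named_key: Optional[str] = None


def resolve_constant(
    entry: ConstantEntry, inline_bytes: bytes, named_data: Dict[str, bytes]
) -> bytes:
    if entry.offset == UINT64_MAX:
        # Constant lives in the named data store, keyed by its SHA-256 digest.
        return named_data[entry.named_key]
    # Otherwise the constant is stored inline in the serialized blob at `offset`.
    return inline_bytes[entry.offset : entry.offset + entry.length]
```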
1 change: 1 addition & 0 deletions backends/vulkan/vulkan_preprocess.py
@@ -229,4 +229,5 @@ def preprocess( # noqa: C901
                 vk_graph, graph_builder.const_tensors, []
             ),
             debug_handle_map=graph_builder.delegate_mapping_builder.get_delegate_mapping(),
+            data_store_output=graph_builder.named_data_store.get_named_data_store_output(),
         )
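The new `data_store_output` field threads the named data collected by the graph builder through to the program serializer, alongside the serialized graph itself. A rough sketch of the shape of that result is below, using stand-in dataclasses rather than the real `PreprocessResult` / `NamedDataStoreOutput` from `executorch.exir`; all field names other than those visible in the diff are illustrative assumptions.

```python
# Stand-in types only; illustrative of the wiring, not the ExecuTorch definitions.
from dataclasses import dataclass, field
from typing import Dict, Tuple


@dataclass
class StubNamedDataStoreOutput:
    # named_key (SHA-256 hex digest) -> raw tensor bytes, placed by the program writer
    buffers: Dict[str, bytes] = field(default_factory=dict)


@dataclass
class StubPreprocessResult:
    processed_bytes: bytes                        # serialized Vulkan graph blob
    debug_handle_map: Dict[int, Tuple[int, ...]]  # delegate debug mapping (stand-in shape)
    data_store_output: StubNamedDataStoreOutput   # constants referenced by named_key
```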