Skip to content

Commit 81d176c

Browse files
pytorchbotSS-JIA
andauthored
[ET-VK][ez] Handle zero-element tensors when building Vulkan graph (#12021)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #11984 by @SS-JIA ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/250/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/250/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/249/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/250/orig @diff-train-skip-merge cc @SS-JIA @manuelcandales @cbilgin --------- Co-authored-by: Stephen Jia <[email protected]>
1 parent ab4217e commit 81d176c

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

backends/vulkan/serialization/vulkan_graph_serialize.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -191,19 +191,23 @@ def serialize_constant_tensors(
191191

192192
current_offset = len(raw_bytes)
193193
for tensor in const_tensors:
194-
array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
195-
array = ctypes.cast(
196-
tensor.untyped_storage().data_ptr(),
197-
ctypes.POINTER(array_type),
198-
).contents
199-
200-
tensor_bytes = bytes(array)
201-
# Pad the tensor bytes to the next 16 byte boundary
202-
raw_bytes += tensor_bytes
203-
raw_bytes += b"\x00" * padding_required(len(tensor_bytes))
204-
205-
vk_graph.constants.append(VkBytes(current_offset, len(tensor_bytes)))
206-
current_offset += aligned_size(len(tensor_bytes))
194+
if tensor.numel() == 0:
195+
vk_graph.constants.append(VkBytes(current_offset, 0))
196+
continue
197+
else:
198+
array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
199+
array = ctypes.cast(
200+
tensor.untyped_storage().data_ptr(),
201+
ctypes.POINTER(array_type),
202+
).contents
203+
204+
tensor_bytes = bytes(array)
205+
# Pad the tensor bytes to the next 16 byte boundary
206+
raw_bytes += tensor_bytes
207+
raw_bytes += b"\x00" * padding_required(len(tensor_bytes))
208+
209+
vk_graph.constants.append(VkBytes(current_offset, len(tensor_bytes)))
210+
current_offset += aligned_size(len(tensor_bytes))
207211

208212

209213
def serialize_custom_shaders(

0 commit comments

Comments
 (0)