Skip to content

Commit 88eb18a

Browse files
committed
[ET-VK][ez] Handle zero-element tensors when building Vulkan graph
## Changes As title. ## Motivation Some models may have parameter tensors which are zero-shape (i.e. no elements). In this case, trying to serialize the tensor data will result in a null pointer exception. Differential Revision: [D77281492](https://our.internmc.facebook.com/intern/diff/D77281492/) ghstack-source-id: 292684344 Pull Request resolved: #11984
1 parent 951aef9 commit 88eb18a

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

backends/vulkan/serialization/vulkan_graph_serialize.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -191,19 +191,23 @@ def serialize_constant_tensors(
191191

192192
current_offset = len(raw_bytes)
193193
for tensor in const_tensors:
194-
array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
195-
array = ctypes.cast(
196-
tensor.untyped_storage().data_ptr(),
197-
ctypes.POINTER(array_type),
198-
).contents
199-
200-
tensor_bytes = bytes(array)
201-
# Pad the tensor bytes to the next 16 byte boundary
202-
raw_bytes += tensor_bytes
203-
raw_bytes += b"\x00" * padding_required(len(tensor_bytes))
204-
205-
vk_graph.constants.append(VkBytes(current_offset, len(tensor_bytes)))
206-
current_offset += aligned_size(len(tensor_bytes))
194+
if tensor.numel() == 0:
195+
vk_graph.constants.append(VkBytes(current_offset, 0))
196+
continue
197+
else:
198+
array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
199+
array = ctypes.cast(
200+
tensor.untyped_storage().data_ptr(),
201+
ctypes.POINTER(array_type),
202+
).contents
203+
204+
tensor_bytes = bytes(array)
205+
# Pad the tensor bytes to the next 16 byte boundary
206+
raw_bytes += tensor_bytes
207+
raw_bytes += b"\x00" * padding_required(len(tensor_bytes))
208+
209+
vk_graph.constants.append(VkBytes(current_offset, len(tensor_bytes)))
210+
current_offset += aligned_size(len(tensor_bytes))
207211

208212

209213
def serialize_custom_shaders(

0 commit comments

Comments
 (0)