@@ -24,6 +24,13 @@ namespace vgf {
2424/* static function to map format to byte count */
2525static uint32_t get_format_size (VkFormat format);
2626
27+ // SPV_ARM_tensor does not support rank-0 representations according to the spec.
28+ // Use an unsqueezed dimension when the resource table contains an empty
29+ // shape. Tensors are output as rank 0 when copied back from the vgf backend.
30+ namespace {
31+ constexpr int64_t kScalarSentinelDimension = 1 ;
32+ }
33+
2734// Debug function to inspect memory properties
2835static string memory_flags_to_string (VkMemoryPropertyFlags flags) {
2936 if (flags == 0 )
@@ -264,7 +271,11 @@ static void debug_print_resources(
264271 the_shape.size (),
265272 the_stride.size ());
266273 for (int j = 0 ; j < the_shape.size (); j++) {
267- ET_LOG (Info, " %d: dim %ld" , j, the_shape[j]);
274+ ET_LOG (
275+ Info,
276+ " %d: dim %lld" ,
277+ j,
278+ static_cast <long long >(the_shape[j]));
268279 }
269280 // Allocate a tensor with bound memory
270281 break ;
@@ -387,6 +398,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
387398 // Get tensor shape and strides
388399 auto shape = resource_decoder->getTensorShape (i);
389400 auto stride = resource_decoder->getTensorStride (i);
401+ const auto shape_size = shape.size ();
390402
391403 switch (resource_decoder->getCategory (i)) {
392404 case vgflib::ResourceCategory::INPUT:
@@ -409,9 +421,9 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
409421 result = allocate_tensor (
410422 vk_physical,
411423 vk_device,
412- vgflib::ToVkFormat (resource_decoder-> getVkFormat (i)) ,
413- static_cast <uint32_t >(shape. size () ),
414- shape.begin (),
424+ resource_format ,
425+ shape_size == 0 ? 1 : static_cast <uint32_t >(shape_size ),
426+ shape_size == 0 ? & kScalarSentinelDimension : shape.begin (),
415427 static_cast <uint32_t >(stride.size ()),
416428 stride.begin (),
417429 &tensor_description,
@@ -422,8 +434,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
422434 ET_LOG (Error, " Failed to allocate tensor for VGF resource %d" , i);
423435 return false ;
424436 }
425- size_t e_size = get_format_size (
426- vgflib::ToVkFormat (resource_decoder->getVkFormat (i)));
437+ size_t e_size = get_format_size (resource_format);
427438 if (0 == e_size) {
428439 ET_LOG (Error, " failed to get element size of VkFormat" );
429440 return false ;
@@ -449,9 +460,11 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
449460 .sType = VK_STRUCTURE_TYPE_TENSOR_DESCRIPTION_ARM,
450461 .pNext = nullptr ,
451462 .tiling = VK_TENSOR_TILING_LINEAR_ARM,
452- .format = vgflib::ToVkFormat (resource_decoder->getVkFormat (i)),
453- .dimensionCount = static_cast <uint32_t >(shape.size ()),
454- .pDimensions = shape.begin (),
463+ .format = resource_format,
464+ .dimensionCount =
465+ shape_size == 0 ? 1 : static_cast <uint32_t >(shape_size),
466+ .pDimensions =
467+ shape_size == 0 ? &kScalarSentinelDimension : shape.begin (),
455468 // Note: stride_data of 0's causes size==0, null means stride==size
456469 .pStrides = (0 == stride.size () ? nullptr : stride.begin ()),
457470 .usage = VK_TENSOR_USAGE_DATA_GRAPH_BIT_ARM,
0 commit comments