Skip to content

Commit 1121cba

Browse files
pytorchbotArmRyan
andauthored
Arm backend: Unsqueeze rank 0 tensor at vgf runtime (#14934)
Rank 0 tensors are not supported in SPV_ARM_tensor. We need to symbolically unsqueeze scalar IOs at runtime. * Remove xfails related to MLETORCH-1410 Change-Id: I1cf46919dec422b15f51faf18d676c661df276a6 cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Ryan O'Shea <[email protected]> Co-authored-by: Ryan OShea <[email protected]>
1 parent c77f720 commit 1121cba

File tree

3 files changed

+22
-11
lines changed

3 files changed

+22
-11
lines changed

backends/arm/runtime/VGFSetup.cpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ namespace vgf {
2424
/* static function to map format to byte count */
2525
static uint32_t get_format_size(VkFormat format);
2626

27+
// SPV_ARM_tensor does not support rank-0 representations according to the spec.
28+
// Use an unsqueezed dimension when the resource table contains an empty
29+
// shape. Tensors are output as rank 0 when copied back from the vgf backend.
30+
namespace {
31+
constexpr int64_t kScalarSentinelDimension = 1;
32+
}
33+
2734
// Debug function to inspect memory properties
2835
static string memory_flags_to_string(VkMemoryPropertyFlags flags) {
2936
if (flags == 0)
@@ -264,7 +271,11 @@ static void debug_print_resources(
264271
the_shape.size(),
265272
the_stride.size());
266273
for (int j = 0; j < the_shape.size(); j++) {
267-
ET_LOG(Info, " %d: dim %ld", j, the_shape[j]);
274+
ET_LOG(
275+
Info,
276+
" %d: dim %lld",
277+
j,
278+
static_cast<long long>(the_shape[j]));
268279
}
269280
// Allocate a tensor with bound memory
270281
break;
@@ -387,6 +398,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
387398
// Get tensor shape and strides
388399
auto shape = resource_decoder->getTensorShape(i);
389400
auto stride = resource_decoder->getTensorStride(i);
401+
const auto shape_size = shape.size();
390402

391403
switch (resource_decoder->getCategory(i)) {
392404
case vgflib::ResourceCategory::INPUT:
@@ -409,9 +421,9 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
409421
result = allocate_tensor(
410422
vk_physical,
411423
vk_device,
412-
vgflib::ToVkFormat(resource_decoder->getVkFormat(i)),
413-
static_cast<uint32_t>(shape.size()),
414-
shape.begin(),
424+
resource_format,
425+
shape_size == 0 ? 1 : static_cast<uint32_t>(shape_size),
426+
shape_size == 0 ? &kScalarSentinelDimension : shape.begin(),
415427
static_cast<uint32_t>(stride.size()),
416428
stride.begin(),
417429
&tensor_description,
@@ -422,8 +434,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
422434
ET_LOG(Error, "Failed to allocate tensor for VGF resource %d", i);
423435
return false;
424436
}
425-
size_t e_size = get_format_size(
426-
vgflib::ToVkFormat(resource_decoder->getVkFormat(i)));
437+
size_t e_size = get_format_size(resource_format);
427438
if (0 == e_size) {
428439
ET_LOG(Error, "failed to get element size of VkFormat");
429440
return false;
@@ -449,9 +460,11 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
449460
.sType = VK_STRUCTURE_TYPE_TENSOR_DESCRIPTION_ARM,
450461
.pNext = nullptr,
451462
.tiling = VK_TENSOR_TILING_LINEAR_ARM,
452-
.format = vgflib::ToVkFormat(resource_decoder->getVkFormat(i)),
453-
.dimensionCount = static_cast<uint32_t>(shape.size()),
454-
.pDimensions = shape.begin(),
463+
.format = resource_format,
464+
.dimensionCount =
465+
shape_size == 0 ? 1 : static_cast<uint32_t>(shape_size),
466+
.pDimensions =
467+
shape_size == 0 ? &kScalarSentinelDimension : shape.begin(),
455468
// Note: stride_data of 0's causes size==0, null means stride==size
456469
.pStrides = (0 == stride.size() ? nullptr : stride.begin()),
457470
.usage = VK_TENSOR_USAGE_DATA_GRAPH_BIT_ARM,

backends/arm/test/ops/test_mean_dim.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
7-
87
import torch
98
from executorch.backends.arm.test import common
109
from executorch.backends.arm.test.tester.test_pipeline import (

backends/arm/test/ops/test_scalar_tensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
65
import torch
76
from executorch.backends.arm.test import common
87

0 commit comments

Comments
 (0)