Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def op_node_is_compatible(
return False, "no operator implementation"
features = get_op_features(target)

# Check for high dimensional tensors
if utils.tensor_node_is_high_dim(node):
return False, "contains high dim tensor"

valid_texture_layouts = utils.possible_node_memory_layouts(
node, self.texture_limits
)
Expand All @@ -94,6 +98,10 @@ def op_node_is_compatible(
and utils.is_tensor_node(arg)
and i not in features.skip_limits_check
):
# Check for high dimensional tensors
if utils.tensor_node_is_high_dim(arg):
return False, "contains high dim tensor"

arg_texture_layouts = utils.possible_node_memory_layouts(
arg, self.texture_limits
)
Expand Down
36 changes: 32 additions & 4 deletions backends/vulkan/test/test_vulkan_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def lower_module_and_test_output(
dynamic_shapes=None,
test_inputs=None,
first_output_only=False,
expect_no_delegates=False,
):
"""
Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with
Expand Down Expand Up @@ -125,10 +126,23 @@ def run_test():
)
executorch_program = edge_program.to_executorch()

self.assertEqual(
executorch_program.executorch_program.execution_plan[0].delegates[0].id,
VulkanBackend.__name__,
)
if expect_no_delegates:
self.assertEqual(
len(
executorch_program.executorch_program.execution_plan[
0
].delegates
),
0,
)
return
else:
self.assertEqual(
executorch_program.executorch_program.execution_plan[0]
.delegates[0]
.id,
VulkanBackend.__name__,
)

executorch_module = _load_for_executorch_from_buffer(
executorch_program.buffer
Expand Down Expand Up @@ -1683,3 +1697,17 @@ def forward(self, x):
GridPriorsModule(),
(torch.rand(size=[1, 5, 2, 3]),),
)

def test_vulkan_backend_high_dim_tensors_fail(self):
class UnsqueezeHigherDim(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.unsqueeze(x, 2)

self.lower_module_and_test_output(
UnsqueezeHigherDim(),
(torch.ones(size=[5, 4, 1, 2, 6]),),
expect_no_delegates=True,
)
18 changes: 18 additions & 0 deletions backends/vulkan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,24 @@ def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> int:
raise RuntimeError(f"Cannot get numel for val of type {type(node.meta['val'])}")


def tensor_node_is_high_dim(node: torch.fx.Node) -> bool:
"""
If the node does not contain a tensor or a collection of tensors, return False.
Otherwise, return True if the tensor is high dimensional (i.e. rank > 4).
"""
if is_tensor_node(node):
if isinstance(node.meta["val"], FakeTensor):
return len(node.meta["val"].shape) > 4
if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple):
for fake_tensor in node.meta["val"]:
if isinstance(fake_tensor, FakeTensor):
if len(fake_tensor.shape) > 4:
return True
return False
else:
return False


def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
"""
Calculate the image extents that will be used to represent a tensor with the given sizes
Expand Down
Loading