diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index cb14e96962d..3c31e0316a6 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -83,6 +83,10 @@ def op_node_is_compatible(
             return False, "no operator implementation"
         features = get_op_features(target)
 
+        # Check for high dimensional tensors
+        if utils.is_tensor_node(node) and utils.tensor_node_is_high_dim(node):
+            return False, "contains high dim tensor"
+
         valid_texture_layouts = utils.possible_node_memory_layouts(
             node, self.texture_limits
         )
@@ -94,6 +98,10 @@ def op_node_is_compatible(
                 and utils.is_tensor_node(arg)
                 and i not in features.skip_limits_check
             ):
+                # Check for high dimensional tensors
+                if utils.is_tensor_node(arg) and utils.tensor_node_is_high_dim(arg):
+                    return False, "contains high dim tensor"
+
                 arg_texture_layouts = utils.possible_node_memory_layouts(
                     arg, self.texture_limits
                 )
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 8a2701a5c02..129f40df8b1 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -97,6 +97,7 @@ def lower_module_and_test_output(
         dynamic_shapes=None,
         test_inputs=None,
         first_output_only=False,
+        expect_no_delegates=False,
     ):
         """
         Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with
@@ -125,10 +126,23 @@ def run_test():
             )
             executorch_program = edge_program.to_executorch()
 
-            self.assertEqual(
-                executorch_program.executorch_program.execution_plan[0].delegates[0].id,
-                VulkanBackend.__name__,
-            )
+            if expect_no_delegates:
+                self.assertEqual(
+                    len(
+                        executorch_program.executorch_program.execution_plan[
+                            0
+                        ].delegates
+                    ),
+                    0,
+                )
+                return
+            else:
+                self.assertEqual(
+                    executorch_program.executorch_program.execution_plan[0]
+                    .delegates[0]
+                    .id,
+                    VulkanBackend.__name__,
+                )
 
             executorch_module = _load_for_executorch_from_buffer(
                 executorch_program.buffer
@@ -1683,3 +1697,17 @@ def forward(self, x):
             GridPriorsModule(),
             (torch.rand(size=[1, 5, 2, 3]),),
         )
+
+    def test_vulkan_backend_high_dim_tensors_fail(self):
+        class UnsqueezeHigherDim(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.unsqueeze(x, 2)
+
+        self.lower_module_and_test_output(
+            UnsqueezeHigherDim(),
+            (torch.ones(size=[5, 4, 1, 2, 6]),),
+            expect_no_delegates=True,
+        )
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 1a030e5e8f5..5034747be9d 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -130,6 +130,20 @@ def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> int:
     raise RuntimeError(f"Cannot get numel for val of type {type(node.meta['val'])}")
 
 
+def tensor_node_is_high_dim(node: torch.fx.Node) -> bool:
+    """
+    Returns true if a given node contains a tensor with more than 4 dimensions
+    """
+    if isinstance(node.meta["val"], FakeTensor):
+        return len(node.meta["val"].shape) > 4
+    if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple):
+        for fake_tensor in node.meta["val"]:
+            if isinstance(fake_tensor, FakeTensor):
+                if len(fake_tensor.shape) > 4:
+                    return True
+    return False
+
+
 def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
     """
     Calculate the image extents that will be used to represent a tensor with the given sizes
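
For context, here is a minimal standalone sketch (not part of the diff) of the rule the partitioner now applies. It re-implements the more-than-4-dims test inline instead of importing `backends.vulkan.utils`, reuses the 5-D input shape from the new test, and uses the public `torch.export` API only to populate `node.meta["val"]` the way the partitioner sees it; `t.dim() > 4` here mirrors `len(fake_tensor.shape) > 4` in the new helper.

```python
import torch
from torch.export import export


class UnsqueezeHigherDim(torch.nn.Module):
    def forward(self, x):
        return torch.unsqueeze(x, 2)  # 5-D input -> 6-D output


# Exporting traces the module with fake tensors, so every node's
# meta["val"] carries the shape information the partitioner inspects.
ep = export(UnsqueezeHigherDim(), (torch.ones(5, 4, 1, 2, 6),))

for node in ep.graph.nodes:
    val = node.meta.get("val")
    # meta["val"] is a FakeTensor, or a list/tuple of them for
    # multi-output ops; any entry with more than 4 dims trips the check.
    tensors = val if isinstance(val, (list, tuple)) else [val]
    if any(isinstance(t, torch.Tensor) and t.dim() > 4 for t in tensors):
        print(f"{node.name}: would be rejected ('contains high dim tensor')")
```

Under these assumptions, both the 5-D placeholder and the 6-D unsqueeze result are flagged, so no node is delegated, which is exactly what the new test asserts via `expect_no_delegates=True`.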