diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 03721066f1c..8983ca67752 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -6,7 +6,7 @@ import logging from copy import deepcopy -from typing import Any, Set +from typing import Any, Optional, Set import executorch.backends.vulkan.utils as utils @@ -94,7 +94,7 @@ def __init__( def propose_node_storage( self, node: torch.fx.Node, - ) -> VkStorageType: + ) -> Optional[VkStorageType]: """ Uses the operator registry to determine the storage type that should be used for a given node. The storage type is determined with the following priorities: @@ -114,6 +114,9 @@ def propose_node_storage( opinionated user can be found, then proceed to the last step. 4. Use the default storage type setting. """ + if not utils.is_tensor_node(node): + return None + # The node may have an input/output tensor that is too big to be stored in a # texture. In this case, buffer storage must be used. Note that the partitioner # has already checked for the fact that buffer storage is supported by the @@ -154,12 +157,15 @@ def propose_node_layout( self, node: torch.fx.Node, storage: VkStorageType, - ) -> VkMemoryLayout: + ) -> Optional[VkMemoryLayout]: """ Performs the same steps as propose_node_storage, but detects the memory layout that should be used for the specific storage type. The same prioritization logic is applied. 
""" + if not utils.is_tensor_node(node): + return None + valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts # pyre-ignore if has_impl(node.target): diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 8502e254ec5..0486110ced6 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -228,6 +228,8 @@ def update_features_impl(op: OpKey): exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + # Symbolic integer ops + torch.ops.aten.sym_size.int, ] ) def register_ephemeral_op(features: OpFeatures): @@ -505,6 +507,7 @@ def register_sdpa_ops(features: OpFeatures): features.texture_impl = TextureImplFeatures( valid_packed_dims={PackedDim.WIDTH}, ) + features.resize_fn = True return features diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 07660c8878f..d690e886d40 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -146,9 +146,8 @@ def op_node_is_compatible( # noqa: C901: Function is too complex def node_is_compatible( self, node: torch.fx.Node, features: Optional[OpFeatures] = None ) -> Tuple[bool, str]: - # TODO(ssjia) support symbolic ints if utils.is_symint_node(node): - return False, "symint node not supported yet" + return node.target in vulkan_supported_ops, "Op is compatible" elif utils.is_tensor_node(node): return self.op_node_is_compatible(node, features=features) @@ -258,7 +257,7 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: if target not in vulkan_supported_ops: # For some ops, i.e. custom ops the name is registered instead of the # OpOverload object. 
- if not isinstance(target, str) and target.name() in vulkan_supported_ops: + if hasattr(target, "name") and target.name() in vulkan_supported_ops: features = vulkan_supported_ops[target.name()] else: self.log_skip(node, "no operator implementation") diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 59fd561a2c5..709e9fa1f12 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -156,6 +156,38 @@ ComputeGraph::~ComputeGraph() { context_->flush(); } +std::vector ComputeGraph::extract_int_or_symint_list( + const ValueRef idx) { + const Value& val = values_.at(idx); + std::vector result; + + if (val.isIntList()) { + // If it's an IntList, return a copy of the list + return val.toConstIntList(); + } else if (val.isValueList()) { + // If it's a ValueList, extract each element as an Int or SymInt + const std::vector& value_list = val.toConstValueList(); + result.reserve(value_list.size()); + + for (const ValueRef& ref : value_list) { + const Value& element = values_.at(ref); + if (element.isInt()) { + result.push_back(element.toInt()); + } else if (element.isSymInt()) { + result.push_back(read_symint(ref)); + } else { + VK_THROW( + "ValueList element is neither Int nor SymInt, but has type ", + element.type()); + } + } + return result; + } + + VK_THROW( + "Cannot extract int or symint list from Value with type ", val.type()); +} + utils::StorageType ComputeGraph::suggested_storage_type() { if (config_.enable_storage_type_override) { return config_.storage_type_override; diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index d09597ad778..d45666c6bbf 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -405,6 +405,15 @@ class ComputeGraph final { return values_.at(idx).toString(); } + /* + * Utility function to extract a list of integers 
from a ValueRef. + * If the ValueRef is an IntList, returns a copy of the list. + * If the ValueRef is a ValueList, extracts each element as an Int or SymInt + * and returns the resulting list. + * Throws an error if the ValueRef is neither an IntList nor a ValueList. + */ + std::vector extract_int_or_symint_list(const ValueRef idx); + template < typename T, typename std::enable_if< diff --git a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp new file mode 100644 index 00000000000..0705b53d394 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace vkcompute { + +void resize_sym_size_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)args; // Unused parameter + + ValueRef out_symint_ref = extra_args[0]; + ValueRef in_tensor_ref = extra_args[1]; + + int64_t dim = graph->extract_scalar(extra_args[2]); + int64_t size_at_dim = graph->size_at(dim, in_tensor_ref); + + graph->set_symint(out_symint_ref, static_cast(size_at_dim)); +} + +/* + * This operator takes a tensor and an integer dimension as inputs, and produces + * a symint as output. The symint's value is the size of the tensor at the + * specified dimension. 
+ */ +void sym_size_int(ComputeGraph& graph, const std::vector& args) { + ValueRef in_tensor = args[0]; + ValueRef dim = args[1]; + ValueRef out_symint = args[2]; + + int64_t dim_val = graph.extract_scalar(dim); + + int64_t size_at_dim = graph.size_at(dim_val, in_tensor); + graph.set_symint(out_symint, static_cast(size_at_dim)); + + graph.execute_nodes().emplace_back( + new ExecuteNode(resize_sym_size_node, {out_symint, in_tensor, dim})); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(sym_size.int, sym_size_int); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index ef71f8d6d29..710ba0d576f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -48,9 +48,9 @@ void resize_view_node( if (extra_args[0] == kDummyValueRef || graph->val_is_none(extra_args[0])) { out->virtual_resize(in->sizes()); } else { - IntListPtr view_sizes = graph->get_int_list(extra_args[0]); - std::vector out_sizes = - compute_out_sizes(in->sizes(), *view_sizes); + std::vector view_sizes = + graph->extract_int_or_symint_list(extra_args[0]); + std::vector out_sizes = compute_out_sizes(in->sizes(), view_sizes); out->virtual_resize(out_sizes); } } diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index d01c8b53b35..76b9240463b 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -21,6 +21,7 @@ is_constant, is_get_attr_node, is_param_node, + is_symint_node, ) from executorch.exir.backend.utils import DelegateMappingBuilder @@ -54,6 +55,8 @@ def __init__( # Mapping from Node to VkValue id self.node_to_value_ids = {} + # Mapping from const scalar value to created VkValue id + self.const_scalar_to_value_ids = {} # For logging self.seen_ops = set() @@ -128,7 +131,7 @@ def maybe_add_constant_tensor(self, 
node: Node) -> int:
 
     def create_node_value(self, node: Node) -> int:
         # If the node has been marked as a scalar tensor, create a SymInt instead of a tensor
-        if node.meta.get("vkdg_is_scalar_tensor", False):
+        if is_symint_node(node) or node.meta.get("vkdg_is_scalar_tensor", False):
             new_id = self.create_symint_value()
             self.node_to_value_ids[node] = new_id
             return new_id
@@ -146,14 +149,24 @@ def create_node_value(self, node: Node) -> int:
             self.node_to_value_ids[node] = new_id
             return new_id
         else:
-            raise RuntimeError(f"Cannot create value for spec of type {type(spec)}")
+            raise RuntimeError(
+                f"Cannot create value for node {node} with spec of type {type(spec)}"
+            )
 
     def create_null_value(self) -> int:
         new_id = len(self.values)
         self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Null()))
         return new_id
 
-    def create_scalar_value(self, scalar: _ScalarType) -> int:
+    def get_or_create_scalar_value(self, scalar: _ScalarType) -> int:
+        # Key the cache on (type, value) rather than the bare value: Python treats
+        # True == 1 == 1.0 (and False == 0 == 0.0) as the same dict key, so a bare
+        # key would hand back e.g. an Int VkValue where a Double was requested.
+        scalar_key = (type(scalar), scalar)
+
+        if scalar_key in self.const_scalar_to_value_ids:
+            return self.const_scalar_to_value_ids[scalar_key]
+
         new_id = len(self.values)
         if isinstance(scalar, bool):
             self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Bool(scalar)))
@@ -161,6 +174,8 @@ def create_scalar_value(self, scalar: _ScalarType) -> int:
             self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar)))
         elif isinstance(scalar, float):
             self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))
+
+        self.const_scalar_to_value_ids[scalar_key] = new_id
         return new_id
 
     def create_symint_value(self) -> int:
@@ -200,28 +215,47 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
 
     def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
         new_id = len(self.values)
+
         if len(arg) == 0:
             self.values.append(
                 vk_graph_schema.VkValue(vk_graph_schema.IntList(items=[]))
             )
-        elif isinstance(arg[0], bool):
+            return new_id
+
+        # bool is a subclass of int, so an all-bool list is also "all int"; the
+        # branches below must stay mutually exclusive, with bool tested first.
+        all_bool = all(isinstance(val, bool) for val in arg)
+        all_int = all(isinstance(val, int) for val in arg)
+        all_float = all(isinstance(val, float) for val in arg)
+        # Plain ints are accepted next to symint nodes, e.g. [sym_size, 64].
+        all_int_or_symint = all(
+            isinstance(val, int) or (isinstance(val, Node) and is_symint_node(val))
+            for val in arg
+        )
+
+        if all_bool:
             self.values.append(
                 vk_graph_schema.VkValue(
                     vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg])
                 )
             )
-        elif isinstance(arg[0], int):
+        elif all_int:
             self.values.append(
                 vk_graph_schema.VkValue(
                     vk_graph_schema.IntList(items=[cast(int, e) for e in arg])
                 )
             )
-        elif isinstance(arg[0], float):
+        elif all_float:
             self.values.append(
                 vk_graph_schema.VkValue(
                     vk_graph_schema.DoubleList(items=[cast(float, e) for e in arg])
                 )
             )
+        elif all_int_or_symint:
+            return self.create_value_list_value(arg)
+        else:
+            raise NotImplementedError(f"Cannot add value for list {arg}")
+
         return new_id
 
     def create_value_list_value(self, arg: tuple | list) -> int:
@@ -256,11 +290,11 @@ def get_or_create_value_for(self, arg: _Argument):
         ):
             return self.create_null_value()
         elif isinstance(arg, _ScalarType):
-            return self.create_scalar_value(arg)
+            return self.get_or_create_scalar_value(arg)
         elif isinstance(arg, TensorSpec):
             return self.create_tensor_value(arg)
         elif isinstance(arg, list) and (
-            len(arg) == 0 or isinstance(arg[0], _ScalarType)
+            len(arg) == 0 or any(isinstance(val, _ScalarType) for val in arg)
         ):
             # pyre-ignore[6]
             return self.create_scalar_list_value(arg)
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index b57710974e8..5833a0cdb72 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -1801,3 +1801,44 @@ def forward(self, x:
torch.Tensor): LinearModel(n_pca_basis, n_sh_basis, n_gaussians), (torch.ones(n_pca_basis),), ) + + def test_vulkan_backend_sym_size_int(self): + """ + Test the sym_size.int operator with a model that: + 1. Takes an input tensor with shape [1, M, K] + 2. Reshapes it to [M, K] + 3. Applies a linear layer + 4. Reshapes the output back to [1, M, N] + """ + K = 64 # Input feature dimension + N = 32 # Output feature dimension + + class SymSizeModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(K, N) + + def forward(self, x): + M = x.size(1) + + reshaped = torch.reshape(x, [M, K]) + output = self.linear(reshaped) + return torch.reshape(output, [1, M, N]) + + sample_inputs = (torch.randn(1, 64, K),) + + batch = Dim("batch", min=1, max=128) + dynamic_shapes = {"x": {1: batch}} + + test_inputs = [ + (torch.randn(1, 32, K),), + (torch.randn(1, 96, K),), + (torch.randn(1, 128, K),), + ] + + self.lower_module_and_test_output( + SymSizeModel(), + sample_inputs, + dynamic_shapes=dynamic_shapes, + test_inputs=test_inputs, + ) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index eb949a6ace8..642f7c5f495 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -101,10 +101,6 @@ def is_tensor_node(node: torch.fx.Node) -> bool: """ Returns true if the given node produces a tensor value, or a collection of tensor values """ - # All nodes with tensor values are tagged by the SpecPropPass transform - if "spec" in node.meta: - return True - if "val" not in node.meta: return False