From ecbdd6b9e52e91350515a741a43ea9d9027e95f5 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Mon, 19 May 2025 13:27:05 -0700 Subject: [PATCH] [ET-VK] Support exporting graphs with symbolic shape ops + update view to accept sym_size args ## Context The ultimate goal is to be able to export the transformer models with dynamic shapes enabled so that batched prefill can be done. With transformer models, when dynamic shapes are turned on, `sym_size` operators appear in the graph which are used to determine the `seq_len` of the inputs, i.e. how many tokens are being passed into the input sequence. The `sym_size` operator accepts a tensor and a dim, and extracts the size of the tensor at the specified dim as a symbolic integer. In the transformer model, the `seq_len` symint is used as an argument to `view` operators. This PR enables exporting graphs with symbolic integer nodes and in particular the `sym_size` operator, as well as handling when symints are used in a list of ints. # Changes * Miscellaneous fixes to fix errors that show occur when symint nodes are encountered * Add C++ implementation of symint nodes and add registration for it * Enable the view operator to handle when the sizes arg includes symints Differential Revision: [D75019798](https://our.internmc.facebook.com/intern/diff/D75019798/) [ghstack-poisoned] --- .../vulkan/_passes/tag_memory_meta_pass.py | 12 +++- backends/vulkan/op_registry.py | 3 + .../vulkan/partitioner/vulkan_partitioner.py | 5 +- .../vulkan/runtime/graph/ComputeGraph.cpp | 32 +++++++++++ backends/vulkan/runtime/graph/ComputeGraph.h | 9 +++ .../runtime/graph/ops/impl/SymIntOps.cpp | 52 ++++++++++++++++++ .../vulkan/runtime/graph/ops/impl/View.cpp | 6 +- .../serialization/vulkan_graph_builder.py | 55 ++++++++++++++++--- backends/vulkan/test/test_vulkan_delegate.py | 41 ++++++++++++++ backends/vulkan/utils.py | 4 -- 10 files changed, 198 insertions(+), 21 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 03721066f1c..8983ca67752 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -6,7 +6,7 @@ import logging from copy import deepcopy -from typing import Any, Set +from typing import Any, Optional, Set import executorch.backends.vulkan.utils as utils @@ -94,7 +94,7 @@ def __init__( def propose_node_storage( self, node: torch.fx.Node, - ) -> VkStorageType: + ) -> Optional[VkStorageType]: """ Uses the operator registry to determine the storage type that should be used for a given node. The storage type is determined with the following priorities: @@ -114,6 +114,9 @@ def propose_node_storage( opinionated user can be found, then proceed to the last step. 4. Use the default storage type setting. """ + if not utils.is_tensor_node(node): + return None + # The node may have an input/output tensor that is too big to be stored in a # texture. In this case, buffer storage must be used. Note that the partitioner # has already checked for the fact that buffer storage is supported by the @@ -154,12 +157,15 @@ def propose_node_layout( self, node: torch.fx.Node, storage: VkStorageType, - ) -> VkMemoryLayout: + ) -> Optional[VkMemoryLayout]: """ Performs the same steps as propose_node_storage, but detects the memory layout that should be used for the specific storage type. The same prioritization logic is applied. """ + if not utils.is_tensor_node(node): + return None + valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts # pyre-ignore if has_impl(node.target): diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 8502e254ec5..0486110ced6 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -228,6 +228,8 @@ def update_features_impl(op: OpKey): exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + # Symbolic integer ops + torch.ops.aten.sym_size.int, ] ) def register_ephemeral_op(features: OpFeatures): @@ -505,6 +507,7 @@ def register_sdpa_ops(features: OpFeatures): features.texture_impl = TextureImplFeatures( valid_packed_dims={PackedDim.WIDTH}, ) + features.resize_fn = True return features diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 07660c8878f..d690e886d40 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -146,9 +146,8 @@ def op_node_is_compatible( # noqa: C901: Function is too complex def node_is_compatible( self, node: torch.fx.Node, features: Optional[OpFeatures] = None ) -> Tuple[bool, str]: - # TODO(ssjia) support symbolic ints if utils.is_symint_node(node): - return False, "symint node not supported yet" + return node.target in vulkan_supported_ops, "Op is compatible" elif utils.is_tensor_node(node): return self.op_node_is_compatible(node, features=features) @@ -258,7 +257,7 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: if target not in vulkan_supported_ops: # For some ops, i.e. custom ops the name is registered instead of the # OpOverload object. - if not isinstance(target, str) and target.name() in vulkan_supported_ops: + if hasattr(target, "name") and target.name() in vulkan_supported_ops: features = vulkan_supported_ops[target.name()] else: self.log_skip(node, "no operator implementation") diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 59fd561a2c5..709e9fa1f12 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -156,6 +156,38 @@ ComputeGraph::~ComputeGraph() { context_->flush(); } +std::vector ComputeGraph::extract_int_or_symint_list( + const ValueRef idx) { + const Value& val = values_.at(idx); + std::vector result; + + if (val.isIntList()) { + // If it's an IntList, return a copy of the list + return val.toConstIntList(); + } else if (val.isValueList()) { + // If it's a ValueList, extract each element as an Int or SymInt + const std::vector& value_list = val.toConstValueList(); + result.reserve(value_list.size()); + + for (const ValueRef& ref : value_list) { + const Value& element = values_.at(ref); + if (element.isInt()) { + result.push_back(element.toInt()); + } else if (element.isSymInt()) { + result.push_back(read_symint(ref)); + } else { + VK_THROW( + "ValueList element is neither Int nor SymInt, but has type ", + element.type()); + } + } + return result; + } + + VK_THROW( + "Cannot extract int or symint list from Value with type ", val.type()); +} + utils::StorageType ComputeGraph::suggested_storage_type() { if (config_.enable_storage_type_override) { return config_.storage_type_override; diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index d09597ad778..d45666c6bbf 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -405,6 +405,15 @@ class ComputeGraph final { return values_.at(idx).toString(); } + /* + * Utility function to extract a list of integers from a ValueRef. + * If the ValueRef is an IntList, returns a copy of the list. + * If the ValueRef is a ValueList, extracts each element as an Int or SymInt + * and returns the resulting list. + * Throws an error if the ValueRef is neither an IntList nor a ValueList. + */ + std::vector extract_int_or_symint_list(const ValueRef idx); + template < typename T, typename std::enable_if< diff --git a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp new file mode 100644 index 00000000000..0705b53d394 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace vkcompute { + +void resize_sym_size_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)args; // Unused parameter + + ValueRef out_symint_ref = extra_args[0]; + ValueRef in_tensor_ref = extra_args[1]; + + int64_t dim = graph->extract_scalar(extra_args[2]); + int64_t size_at_dim = graph->size_at(dim, in_tensor_ref); + + graph->set_symint(out_symint_ref, static_cast(size_at_dim)); +} + +/* + * This operator takes a tensor and an integer dimension as inputs, and produces + * a symint as output. The symint's value is the size of the tensor at the + * specified dimension. + */ +void sym_size_int(ComputeGraph& graph, const std::vector& args) { + ValueRef in_tensor = args[0]; + ValueRef dim = args[1]; + ValueRef out_symint = args[2]; + + int64_t dim_val = graph.extract_scalar(dim); + + int64_t size_at_dim = graph.size_at(dim_val, in_tensor); + graph.set_symint(out_symint, static_cast(size_at_dim)); + + graph.execute_nodes().emplace_back( + new ExecuteNode(resize_sym_size_node, {out_symint, in_tensor, dim})); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(sym_size.int, sym_size_int); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index ef71f8d6d29..710ba0d576f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -48,9 +48,9 @@ void resize_view_node( if (extra_args[0] == kDummyValueRef || graph->val_is_none(extra_args[0])) { out->virtual_resize(in->sizes()); } else { - IntListPtr view_sizes = graph->get_int_list(extra_args[0]); - std::vector out_sizes = - compute_out_sizes(in->sizes(), *view_sizes); + std::vector view_sizes = + graph->extract_int_or_symint_list(extra_args[0]); + std::vector out_sizes = compute_out_sizes(in->sizes(), view_sizes); out->virtual_resize(out_sizes); } } diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index d01c8b53b35..76b9240463b 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -21,6 +21,7 @@ is_constant, is_get_attr_node, is_param_node, + is_symint_node, ) from executorch.exir.backend.utils import DelegateMappingBuilder @@ -54,6 +55,8 @@ def __init__( # Mapping from Node to VkValue id self.node_to_value_ids = {} + # Mapping from const scalar value to created VkValue id + self.const_scalar_to_value_ids = {} # For logging self.seen_ops = set() @@ -128,7 +131,7 @@ def maybe_add_constant_tensor(self, node: Node) -> int: def create_node_value(self, node: Node) -> int: # If the node has been marked as a scalar tensor, create a SymInt instead of a tensor - if node.meta.get("vkdg_is_scalar_tensor", False): + if is_symint_node(node) or node.meta.get("vkdg_is_scalar_tensor", False): new_id = self.create_symint_value() self.node_to_value_ids[node] = new_id return new_id @@ -146,14 +149,26 @@ def create_node_value(self, node: Node) -> int: self.node_to_value_ids[node] = new_id return new_id else: - raise RuntimeError(f"Cannot create value for spec of type {type(spec)}") + raise RuntimeError( + f"Cannot create value for node {node} with spec of type {type(spec)}" + ) def create_null_value(self) -> int: new_id = len(self.values) self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Null())) return new_id - def create_scalar_value(self, scalar: _ScalarType) -> int: + def get_or_create_scalar_value(self, scalar: _ScalarType) -> int: + scalar_key = scalar + # Since Python considers 1 and True to be "equivalent" (as well as 0 and False) + # to distinguish entries in the dictionary, if scalar is bool then convert it + # to a string representation to use as a key for the dictionary + if isinstance(scalar, bool): + scalar_key = str(scalar) + + if scalar_key in self.const_scalar_to_value_ids: + return self.const_scalar_to_value_ids[scalar_key] + new_id = len(self.values) if isinstance(scalar, bool): self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Bool(scalar))) @@ -161,6 +176,8 @@ def create_scalar_value(self, scalar: _ScalarType) -> int: self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar))) elif isinstance(scalar, float): self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar))) + + self.const_scalar_to_value_ids[scalar_key] = new_id return new_id def create_symint_value(self) -> int: @@ -200,28 +217,50 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: def create_scalar_list_value(self, arg: List[_ScalarType]) -> int: new_id = len(self.values) + if len(arg) == 0: self.values.append( vk_graph_schema.VkValue(vk_graph_schema.IntList(items=[])) ) - elif isinstance(arg[0], bool): + + all_bool = True + all_int = True + all_float = True + all_int_or_symint = True + + for val in arg: + if not isinstance(val, bool): + all_bool = False + if not isinstance(val, int): + all_int = False + if not (isinstance(val, Node) and is_symint_node(val)): + all_int_or_symint = False + if not isinstance(val, float): + all_float = False + + if all_bool: self.values.append( vk_graph_schema.VkValue( vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg]) ) ) - elif isinstance(arg[0], int): + if all_int: self.values.append( vk_graph_schema.VkValue( vk_graph_schema.IntList(items=[cast(int, e) for e in arg]) ) ) - elif isinstance(arg[0], float): + elif all_float: self.values.append( vk_graph_schema.VkValue( vk_graph_schema.DoubleList(items=[cast(float, e) for e in arg]) ) ) + elif all_int_or_symint: + return self.create_value_list_value(arg) + else: + raise NotImplementedError(f"Cannot add value for list {arg}") + return new_id def create_value_list_value(self, arg: tuple | list) -> int: @@ -256,11 +295,11 @@ def get_or_create_value_for(self, arg: _Argument): ): return self.create_null_value() elif isinstance(arg, _ScalarType): - return self.create_scalar_value(arg) + return self.get_or_create_scalar_value(arg) elif isinstance(arg, TensorSpec): return self.create_tensor_value(arg) elif isinstance(arg, list) and ( - len(arg) == 0 or isinstance(arg[0], _ScalarType) + len(arg) == 0 or any(isinstance(val, _ScalarType) for val in arg) ): # pyre-ignore[6] return self.create_scalar_list_value(arg) diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index b57710974e8..5833a0cdb72 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -1801,3 +1801,44 @@ def forward(self, x: torch.Tensor): LinearModel(n_pca_basis, n_sh_basis, n_gaussians), (torch.ones(n_pca_basis),), ) + + def test_vulkan_backend_sym_size_int(self): + """ + Test the sym_size.int operator with a model that: + 1. Takes an input tensor with shape [1, M, K] + 2. Reshapes it to [M, K] + 3. Applies a linear layer + 4. Reshapes the output back to [1, M, N] + """ + K = 64 # Input feature dimension + N = 32 # Output feature dimension + + class SymSizeModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(K, N) + + def forward(self, x): + M = x.size(1) + + reshaped = torch.reshape(x, [M, K]) + output = self.linear(reshaped) + return torch.reshape(output, [1, M, N]) + + sample_inputs = (torch.randn(1, 64, K),) + + batch = Dim("batch", min=1, max=128) + dynamic_shapes = {"x": {1: batch}} + + test_inputs = [ + (torch.randn(1, 32, K),), + (torch.randn(1, 96, K),), + (torch.randn(1, 128, K),), + ] + + self.lower_module_and_test_output( + SymSizeModel(), + sample_inputs, + dynamic_shapes=dynamic_shapes, + test_inputs=test_inputs, + ) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index eb949a6ace8..642f7c5f495 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -101,10 +101,6 @@ def is_tensor_node(node: torch.fx.Node) -> bool: """ Returns true if the given node produces a tensor value, or a collection of tensor values """ - # All nodes with tensor values are tagged by the SpecPropPass transform - if "spec" in node.meta: - return True - if "val" not in node.meta: return False