From 1c32be52c4cbf61e8ae7d52ccc3c64c39e9686e8 Mon Sep 17 00:00:00 2001 From: Sicheng Jia Date: Sat, 4 Oct 2025 00:00:31 -0400 Subject: [PATCH] [ET-VK] Miscellaneous fixes (#14801) Collecting fixes for various models/ops in this diff/PR. They have all been squashed into this single change to make it easier to cherry pick. # Fixes ## Wav2Letter Type: Output correctness failure This is caused by a bug in swiftshader, and is not reproducible on any other platform. Specifically, the issue is in the softmax shader; the exact cause of the issue is unknown, but it is related to using shared memory within shaders. The workaround for this issue is to use separate shared memory arrays for the shared max and shared sum. ## ConvNeXT Type: Exception during runtime This is caused by an incompatible memory layout being used for mean2d. More technically, the packed dimension of the tensor cannot be one of the dims being reduced. The current operator registry system did not have a way to select valid tensor representations based on the actual arguments of an op. To fix, we have to introduce a mechanism for ops to specify valid representations once a node's arguments are known. Once the model is exported with a supported memory layout, the model test passes. ## Inception_V3/ViT Type: Exception during runtime The root cause of this was an interaction between the fuse batch norm pass and how `vulkan_preprocess.py` was applying passes. Essentially, the fuse batch norm pass creates a new param node for the fused weight, but after the pass is applied `_copy_module` is used to copy the transformed graph back into the ExportedProgram. However, it seems that _copy_module lowercases the node names without updating the exported program's graph signature. Therefore, subsequent passes couldn't recognize the weight tensor of convolution nodes as a constant/parameter node. The solution was to migrate vulkan_preprocess.py to use the _transform() API instead of using _copy_module. 
## DenseNet 161 (w/ dynamic shapes) Type: Output Mismatch Cause: the native_batch_norm op doesn't support dynamic shapes. However, the backend test runner doesn't set the correct compile option to filter ops without dynamic shape support. Differential Revision: [D83703496](https://our.internmc.facebook.com/intern/diff/D83703496/) [ghstack-poisoned] (cherry picked from commit 3f0896a5d9dd70f5c21bf2368640d748192f0238) --- .github/workflows/pull.yml | 7 +- backends/vulkan/_passes/fold_qdq.py | 5 +- backends/vulkan/_passes/fuse_patterns.py | 10 +- backends/vulkan/_passes/fuse_quantized_ops.py | 10 +- .../vulkan/_passes/tag_memory_meta_pass.py | 4 + backends/vulkan/op_registry.py | 93 +++++++++---- .../vulkan/partitioner/vulkan_partitioner.py | 10 +- backends/vulkan/patterns/quantized_linear.py | 12 +- .../vulkan/runtime/graph/ops/glsl/conv2d.glsl | 2 +- .../runtime/graph/ops/glsl/conv2d_dw.glsl | 2 +- .../graph/ops/glsl/conv2d_dw_output_tile.glsl | 4 + .../vulkan/runtime/graph/ops/glsl/full.yaml | 1 + .../runtime/graph/ops/glsl/softmax.glsl | 27 ++-- .../runtime/graph/ops/impl/BatchNorm.cpp | 14 +- .../vulkan/runtime/graph/ops/impl/Permute.cpp | 8 +- .../vulkan/runtime/graph/ops/impl/Pool.cpp | 4 +- .../vulkan/runtime/graph/ops/impl/Squeeze.cpp | 9 +- backends/vulkan/test/TARGETS | 1 - backends/vulkan/test/test_vulkan_passes.py | 70 +--------- backends/vulkan/test/utils.py | 4 +- backends/vulkan/utils.py | 19 ++- backends/vulkan/vulkan_preprocess.py | 59 ++++---- examples/vulkan/export.py | 127 +++++++++++------- 23 files changed, 298 insertions(+), 204 deletions(-) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index c7f6c126a08..80cf9adc184 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -959,11 +959,16 @@ jobs: PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh --build # Test models serially - models="mv2 mv3 edsr resnet18 resnet50 dl3" + models="mv2 mv3 edsr resnet18 resnet50 dl3 w2l ic3 ic4" 
for model in $models; do python -m examples.vulkan.export --model_name=$model --test done + # For selected vision models, test with dynamic shapes + models="mv2 resnet18 resnet50 ic3 densenet161" + for model in $models; do + python -m examples.vulkan.export --model_name=$model --test -d + done test-vulkan-operators-linux: name: test-vulkan-operators-linux diff --git a/backends/vulkan/_passes/fold_qdq.py b/backends/vulkan/_passes/fold_qdq.py index 3beccc2205c..a6a5e751c05 100644 --- a/backends/vulkan/_passes/fold_qdq.py +++ b/backends/vulkan/_passes/fold_qdq.py @@ -17,9 +17,8 @@ class FoldQDQPass(ExportPass): valid quant op patterns have already been fused before this pass. """ - def __init__(self, edge_program: torch.export.ExportedProgram): - super(FoldQDQPass, self).__init__() - self.edge_program = edge_program + def __init__(self): + super().__init__() def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: diff --git a/backends/vulkan/_passes/fuse_patterns.py b/backends/vulkan/_passes/fuse_patterns.py index 6ced1f32a7c..1575dd6a4f6 100644 --- a/backends/vulkan/_passes/fuse_patterns.py +++ b/backends/vulkan/_passes/fuse_patterns.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from typing import Optional + import executorch.backends.vulkan.patterns as vk_patterns import torch @@ -13,13 +15,15 @@ class FusePatternsPass(ExportPass): - def __init__(self, exported_program: ExportedProgram) -> None: + def __init__(self) -> None: super().__init__() - self.program = exported_program + self._exported_program: Optional[ExportedProgram] = None def call(self, graph_module: torch.fx.GraphModule): + assert self._exported_program is not None + total_replaced = vk_patterns.replace_all_fusable_subgraphs( - self.program, graph_module + self._exported_program, graph_module ) if total_replaced > 0: diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py index ca9f7541159..bb8cf5f2e64 100644 --- a/backends/vulkan/_passes/fuse_quantized_ops.py +++ b/backends/vulkan/_passes/fuse_quantized_ops.py @@ -211,18 +211,20 @@ def fuse_into_linear_qcnw_node( class FuseQuantizedOpsTransform(ExportPass): - def __init__(self, exported_program: ExportedProgram) -> None: + def __init__(self) -> None: super().__init__() - self.program = exported_program + self._exported_program: Optional[ExportedProgram] = None def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + assert self._exported_program is not None + for node in graph_module.graph.nodes: # Check for linear_qcnw pattern (weight-only quantization) - qcnw_details = matches_linear_qcnw_pattern(self.program, node) + qcnw_details = matches_linear_qcnw_pattern(self._exported_program, node) if qcnw_details is not None: qcnw_method, qcnw_nbits = qcnw_details fuse_into_linear_qcnw_node( - self.program, graph_module, node, qcnw_method, qcnw_nbits + self._exported_program, graph_module, node, qcnw_method, qcnw_nbits ) continue diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index db53cc666a8..8ed71aa1dae 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ 
b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -230,6 +230,10 @@ def get_arg_tensor_source_repset( """ arg_node = op_node.args[arg_i] + # For non-tensor arguments, return ANY_STORAGE + if not utils.is_tensor_arg_node(arg_node): + return utils.ANY_STORAGE + # Special case for cat - use the first tensor in the list as representative if isinstance(arg_node, list): arg_node = arg_node[0] diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 4c686e0cfc5..8eb47ff467e 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -16,8 +16,6 @@ import torch -from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout - from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload @@ -48,6 +46,9 @@ class OpFeatures: # Optional check function used during partitioning to determine if a node's # inputs are supported by the operator implementation. "are_node_inputs_supported_fn", + # Optional function to determine valid representation sets for input and outputs + # once a node's actual inputs are known. 
+ "pick_io_storage_fn", ] def __init__( @@ -61,6 +62,7 @@ def __init__( supports_resize: bool = False, supports_prepacking: bool = False, are_node_inputs_supported_fn: Optional[Callable] = allow_node, + pick_io_storage_fn: Optional[Callable] = None, ): self.inputs_storage: utils.TensorRepSetList = utils.TensorRepSetList( inputs_storage if inputs_storage is not None else [] @@ -77,15 +79,21 @@ def __init__( self.supports_prepacking = supports_prepacking self.are_node_inputs_supported_fn = are_node_inputs_supported_fn + self.pick_io_storage_fn = pick_io_storage_fn def make_op_repsets( self, op_node: torch.fx.Node, texture_limits: utils.ImageExtents = utils.DEFAULT_TEXTURE_LIMITS, ) -> utils.OpRepSets: - return utils.OpRepSets( - self.inputs_storage, self.outputs_storage, op_node, texture_limits - ) + inputs_storage = self.inputs_storage + outputs_storage = self.outputs_storage + if self.pick_io_storage_fn is not None: + i_storage, o_storage = self.pick_io_storage_fn(op_node) + inputs_storage = utils.TensorRepSetList(i_storage) + outputs_storage = utils.TensorRepSetList(o_storage) + + return utils.OpRepSets(inputs_storage, outputs_storage, op_node, texture_limits) ####################### @@ -410,28 +418,16 @@ def register_softmax_op(): ) def register_reduce_op(): def check_reduce_node(node: torch.fx.Node) -> bool: + # Only one argument implies that the reduction is over the entire tensor, which + # is not supported yet. + if len(node.args) == 1: + return False + dim_list = node.args[1] + # Only 1D and 2D reductions are supported at the moment. 
if isinstance(dim_list, list) and len(dim_list) > 2: return False - if isinstance(dim_list, list) and len(dim_list) == 2: - # Try to get the memory layout for this node - try: - memory_layout = utils.get_node_memory_layout(node) - - # If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension - if ( - memory_layout is not None - and memory_layout != VkMemoryLayout.DEFAULT_LAYOUT - ): - # For now only default layout is supported for 2D reduction. - # Because we can't determine if the input is NCHW or NHWC here, - # assume the reduction dimension is packed so we cannot support it. - return False - except (AssertionError, KeyError, AttributeError): - # If we can't get memory layout information, we'll assume the dims aren't packed - pass - def try_find_keepdim_arg(node: torch.fx.Node) -> bool: for arg in node.args: if isinstance(arg, bool): @@ -446,10 +442,41 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool: return True + def pick_io_storage_for_reduce(node: torch.fx.Node): + inputs_storage = utils.ANY_TEXTURE + outputs_storage = utils.ANY_TEXTURE + + input_tensor = node.args[0] + ndim = input_tensor.meta["val"].ndim + dim_list = node.args[1] + if isinstance(dim_list, list) and len(dim_list) == 2: + reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim) + reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim) + + possible_packed_dims = {0, 1, 2} + possible_packed_dims.discard(reduce_dim1_whcn) + possible_packed_dims.discard(reduce_dim2_whcn) + + packed_dim = possible_packed_dims.pop() + assert packed_dim in [0, 1, 2] + + if packed_dim == 0: + inputs_storage = utils.WIDTH_PACKED_TEXTURE + outputs_storage = utils.WIDTH_PACKED_TEXTURE + elif packed_dim == 1: + inputs_storage = utils.HEIGHT_PACKED_TEXTURE + outputs_storage = utils.HEIGHT_PACKED_TEXTURE + else: + inputs_storage = utils.CHANNELS_PACKED_TEXTURE + outputs_storage = utils.CHANNELS_PACKED_TEXTURE + + return inputs_storage, 
outputs_storage + return OpFeatures( inputs_storage=utils.ANY_TEXTURE, supports_resize=True, are_node_inputs_supported_fn=check_reduce_node, + pick_io_storage_fn=pick_io_storage_for_reduce, ) @@ -474,6 +501,23 @@ def register_2d_pool_op(): ] ) def register_convolution_op(): + def check_conv_node(node: torch.fx.Node) -> bool: + x = node.args[0] + x_shape = x.meta["val"].size() + # 4-D input implies 2D convolution + if len(x_shape) == 4: + batches = x.meta["val"].size()[0] + if batches != 1: + return False + # 3-D input implies 1D convolution + if len(x_shape) == 3: + transpose = node.args[6] + # Transposed 1D convolution is not supported yet + if transpose: + return False + + return True + return OpFeatures( inputs_storage=[ utils.CHANNELS_PACKED_TEXTURE, # input @@ -490,6 +534,7 @@ def register_convolution_op(): ], supports_resize=True, supports_prepacking=True, + are_node_inputs_supported_fn=check_conv_node, ) @@ -666,6 +711,7 @@ def register_ported_ops_with_prepacking(): return OpFeatures( inputs_storage=utils.CHANNELS_PACKED_TEXTURE, supports_prepacking=True, + supports_resize=True, ) @@ -696,6 +742,7 @@ def register_ported_ops_with_prepacking_all_dims(): return OpFeatures( inputs_storage=utils.ANY_TEXTURE, supports_prepacking=True, + supports_resize=True, ) diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index e5b2d0f7864..0bdc16616ef 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -36,7 +36,7 @@ Partitioner, PartitionResult, ) -from executorch.exir.backend.utils import tag_constant_data +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram @@ -254,9 +254,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901 self.log_skip(node, "permute node of non compatible 
linear node") return False - is_in_local_scalar_dense_chain, dst_node_is_compatible = ( - self.is_in_local_scalar_dense_chain(node) - ) + ( + is_in_local_scalar_dense_chain, + dst_node_is_compatible, + ) = self.is_in_local_scalar_dense_chain(node) if is_in_local_scalar_dense_chain and dst_node_is_compatible: return True elif is_in_local_scalar_dense_chain: @@ -419,6 +420,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: logger.info(f"Found {pl} Vulkan subgraphs to be partitioned.") tag_constant_data(exported_program) + tag_mutated_buffer(exported_program) return PartitionResult( tagged_exported_program=exported_program, partition_tags=partition_tags diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index 882d0d41e6d..374e29c634d 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -92,9 +92,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None: return # Identify input node - self.fp_input_node, self.quantize_input_node, dq_node = ( - utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0]) - ) + ( + self.fp_input_node, + self.quantize_input_node, + dq_node, + ) = utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0]) assert self.fp_input_node is not None self.all_nodes.append(self.fp_input_node) @@ -386,7 +388,7 @@ def make_linear_dq8ca_q4gsw_op( weight_sums_node = create_constant_placeholder( exp_program=ep, graph=graph_module.graph, - kind=InputKind.CONSTANT_TENSOR, + kind=InputKind.PARAMETER, name=sums_name, data=sum_per_quant_group, ) @@ -429,7 +431,7 @@ def make_linear_q8ta_q8csw_custom_op( weight_sums_node = create_constant_placeholder( exp_program=ep, graph=graph_module.graph, - kind=InputKind.CONSTANT_TENSOR, + kind=InputKind.PARAMETER, name=sums_name, data=sum_per_output_channel, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl index 
0f5dbc41273..88746c5594e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl @@ -60,7 +60,7 @@ void main() { int num_steps = ((-ipos.y) + dilation.y - 1) / dilation.y; start.y = ipos.y + num_steps * dilation.y; } - const ivec2 end = min(ipos + overlay_region.xy, ivec2(in_sizes.xy)); + const ivec2 end = min(ipos + overlay_region.xy, in_sizes.xy); // Compute the start of the kernel based on how far we are skipping ahead when // reading the input. Note that these are "canonical" indices. ivec2 kstart = (start - ipos) / dilation; diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl index 02fbef29b75..9089f87d658 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl @@ -54,7 +54,7 @@ void main() { // Compute the start and end of the input indices to load. Padding is assumed // to be constant 0 padding, so reads from the padding region are skipped. 
const ivec2 start = ipos; - const ivec2 end = ipos + overlay_region.xy; + const ivec2 end = min(ipos + overlay_region.xy, in_sizes.xy); VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0); int kx = 0; diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl index 19250419baf..7448b042cad 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl @@ -97,6 +97,10 @@ void main() { for (int y = start.y, i = 0; i < TILE_SIZE + BATCH_SIZE_Y - 1; y += dilation.y, i++) { for (int x = start.x, j = 0; j < TILE_SIZE + BATCH_SIZE_X - 1; x += dilation.x, j++) { in_texels[j] = texelFetch(t_in, ivec3(x, y, pos.z), 0); + // Set to zero if reading out of bounds + if (any(greaterThanEqual(ivec2(x, y), in_sizes.xy))) { + in_texels[j] = VEC4_T(0); + } } // from 2nd iteration onwards accumulate dot product in 2nd sum diff --git a/backends/vulkan/runtime/graph/ops/glsl/full.yaml b/backends/vulkan/runtime/graph/ops/glsl/full.yaml index eff78a7938d..1a5b0cb235e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/full.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/full.yaml @@ -14,5 +14,6 @@ full: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 shader_variants: - NAME: full diff --git a/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl index d35492bc367..86a2229c416 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl @@ -42,7 +42,8 @@ layout(constant_id = 5) const int group_dim = 1; // work group will write into its assigned element in the shared array. 
#define MAX_NTHREADS 16 -shared vec4 shared_vecs[MAX_NTHREADS]; +shared vec4 shared_max[MAX_NTHREADS]; +shared vec4 shared_sum[MAX_NTHREADS]; #include "indexing_utils.h" @@ -102,13 +103,13 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) { i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) { max_elements = max(max_elements, load_texel(tin, scan_pos)); } - shared_vecs[smi] = max_elements; + shared_max[smi] = max_elements; barrier(); // Iterate over the partial maximums to obtain the overall maximum group_i = tid.y * NWORKERS; - max_elements = shared_vecs[group_i++]; + max_elements = shared_max[group_i++]; for (int i = 1; i < NWORKERS; ++i, group_i++) { - max_elements = max(max_elements, shared_vecs[group_i]); + max_elements = max(max_elements, shared_max[group_i]); } scan_pos[reduce_dim] = tid.x; @@ -118,13 +119,13 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) { i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) { denominators += exp(load_texel(tin, scan_pos) - max_elements); } - shared_vecs[smi] = denominators; + shared_sum[smi] = denominators; barrier(); // Iterate over the partial sums to obtain the overall sum group_i = tid.y * NWORKERS; - denominators = shared_vecs[group_i++]; + denominators = shared_sum[group_i++]; for (int i = 1; i < NWORKERS; ++i, group_i++) { - denominators += shared_vecs[group_i]; + denominators += shared_sum[group_i]; } // Determine if there are any padding elements in the final texel of the @@ -184,13 +185,13 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) { max_elements.x = max(intex[i], max_elements.x); } } - shared_vecs[smi] = max_elements; + shared_max[smi] = max_elements; barrier(); // Iterate over the partial maximums to obtain the overall maximum group_i = tid.y * NWORKERS; - max_elements = shared_vecs[group_i++]; + max_elements = shared_max[group_i++]; for (int i = 1; i < NWORKERS; ++i, group_i++) { - max_elements = max(max_elements, shared_vecs[group_i]); + max_elements = max(max_elements, 
shared_max[group_i]); } // Each element of the texel is itself a partial maximum; iterate over the // texel to find the actual maximum @@ -214,13 +215,13 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) { denominators.x += exp(intex[i] - max_element); } } - shared_vecs[smi] = denominators; + shared_sum[smi] = denominators; barrier(); // Iterate over the partial sums to obtain the overall sum group_i = tid.y * NWORKERS; - denominators = shared_vecs[group_i++]; + denominators = shared_sum[group_i++]; for (int i = 1; i < NWORKERS; ++i, group_i++) { - denominators += shared_vecs[group_i]; + denominators += shared_sum[group_i]; } // Reduce over the accumulated texel to find the overall sum float denominator = 0; diff --git a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp index 757afd06849..a6dd8f07f53 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp @@ -19,6 +19,18 @@ namespace vkcompute { +void resize_batch_norm_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); + + // For batch norm, output dimensions are the same as input dimensions + std::vector new_out_sizes = graph->sizes_of(self); + graph->virtual_resize(out, new_out_sizes); +} + ValueRef check_and_prepack_arg( ComputeGraph& graph, ValueRef arg_ref, @@ -101,7 +113,7 @@ void add_native_batch_norm_node( // Resize Args {}, // Resizing Logic - nullptr)); + resize_batch_norm_node)); } void native_batch_norm(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp index 9ac4c963bc3..329620e80e6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp @@ -109,11 +109,15 @@ 
void add_permute_node( { IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims); const int32_t permute_ndim = - utils::safe_downcast(permute_dims_ptr->size()); + utils::safe_downcast(permute_dims_ptr->size()); for (int32_t nchw_i = permute_ndim - 1, whcn_i = 0; nchw_i >= 0; nchw_i--, whcn_i++) { - const int32_t permute_dim_nchw = permute_dims_ptr->at(nchw_i); + int32_t permute_dim_nchw = + utils::safe_downcast(permute_dims_ptr->at(nchw_i)); + if (permute_dim_nchw < 0) { + permute_dim_nchw += permute_ndim; + } const int32_t permute_dim_whcn = permute_ndim - 1 - permute_dim_nchw; whcn_permute_dims[whcn_i] = permute_dim_whcn; diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index 250fcdd5490..879f59667d6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -137,7 +137,7 @@ void max_pool2d(ComputeGraph& graph, const std::vector& args) { struct DivisorParams final { int32_t divisor_override; - bool count_include_pad; + int32_t count_include_pad; }; DivisorParams create_divisor_params( @@ -148,7 +148,7 @@ DivisorParams create_divisor_params( graph.val_is_int(divisor_override) ? static_cast(graph.get_int(divisor_override)) : 0, - graph.get_bool(count_include_pad)}; + int32_t(graph.get_bool(count_include_pad))}; } void add_avg_pool2d_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp index 13801b45cc7..e2b73b2f3f2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp @@ -32,8 +32,13 @@ void add_squeeze_copy_dims_node( // 2. Squeeze outter most dim // For these cases, just pass input to output via clone. 
for (int i = 0; i < dims.size(); ++i) { - if (dims.at(i) != 0 && in_sizes.at(dims.at(i)) == 1) { - squeeze_dims.push_back(dims.at(i)); + // adjust negative dims + int64_t dim_val = dims.at(i); + if (dim_val < 0) { + dim_val += in_dim; + } + if (dims.at(i) != 0 && in_sizes.at(dim_val) == 1) { + squeeze_dims.push_back(dim_val); } } if (squeeze_dims.size() == 0) { diff --git a/backends/vulkan/test/TARGETS b/backends/vulkan/test/TARGETS index 53fad86f90c..ee296a4f68f 100644 --- a/backends/vulkan/test/TARGETS +++ b/backends/vulkan/test/TARGETS @@ -34,7 +34,6 @@ python_unittest( deps = [ "//caffe2:torch", "//executorch/backends/vulkan/_passes:vulkan_passes", - "//executorch/backends/vulkan/quantizer:vulkan_quantizer", "//executorch/backends/vulkan:vulkan_preprocess", "//pytorch/ao:torchao", # @manual ] diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index 4a30ab6c2de..438126a179f 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -3,15 +3,8 @@ import torch -from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform -from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass -from executorch.backends.vulkan.quantizer.vulkan_quantizer import ( - get_symmetric_quantization_config, - VulkanQuantizer, -) - from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge from executorch.exir.backend.canonical_partitioners.config_partitioner import ( @@ -94,66 +87,6 @@ def op_node_count(graph_module: torch.fx.GraphModule, canonical_op_name: str) -> class TestVulkanPasses(unittest.TestCase): - def test_fuse_int8pack_mm(self): - K = 256 - N = 256 - model = SingleLinearModule(K, N) - sample_inputs = model.get_sample_inputs() - - quantizer = VulkanQuantizer() - quantizer.set_global( - get_symmetric_quantization_config(is_dynamic=False, 
weight_bits=8) - ) - - edge_manager = quantize_and_lower_module( - model, - sample_inputs, - quantizer, - ) - - ep = edge_manager._edge_programs["forward"] - edge_manager.transform( - [ - AddmmToLinearTransform(), - FuseQuantizedOpsTransform(ep), - ] - ) - - gm = ep.graph_module - - self.assertEqual(op_node_count(gm, "_weight_int8pack_mm.default"), 1) - self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) - - def test_fuse_linear_qcs4w(self): - K = 256 - N = 256 - model = SingleLinearModule(K, N) - sample_inputs = model.get_sample_inputs() - - quantizer = VulkanQuantizer() - quantizer.set_global( - get_symmetric_quantization_config(is_dynamic=False, weight_bits=4) - ) - - edge_manager = quantize_and_lower_module( - model, - sample_inputs, - quantizer, - ) - - ep = edge_manager._edge_programs["forward"] - edge_manager.transform( - [ - AddmmToLinearTransform(), - FuseQuantizedOpsTransform(ep), - ] - ) - - gm = ep.graph_module - - self.assertEqual(op_node_count(gm, "linear_qcs4w.default"), 1) - self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) - def test_fuse_rotary_emb(self): """Test conversion of rotary embedding pattern to et_vk.apply_rotary_emb custom op.""" @@ -238,7 +171,8 @@ def _reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor): # Apply the rotary embedding pass ep = edge_manager._edge_programs["forward"] - rotary_pass = FusePatternsPass(ep) + rotary_pass = FusePatternsPass() + rotary_pass._exported_program = ep result = rotary_pass.call(ep.graph_module) # Verify that the pass was successful diff --git a/backends/vulkan/test/utils.py b/backends/vulkan/test/utils.py index 41c1d92bd00..e3d15de644f 100644 --- a/backends/vulkan/test/utils.py +++ b/backends/vulkan/test/utils.py @@ -90,7 +90,9 @@ def export_model_to_vulkan( qmode=QuantizationMode.NONE, ): compile_options = {} - exported_graph = get_exported_graph(model, sample_inputs, qmode=qmode) + exported_graph = get_exported_graph( + model, 
sample_inputs, dynamic_shapes=dynamic_shapes, qmode=qmode + ) program = export( exported_graph, sample_inputs, diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 96f200eecbc..892b667a7bb 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -128,7 +128,7 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool: is_get_attr_node(node) or is_param(program, node) or is_buffer(program, node) - or is_constant(program, node) + or is_lifted_tensor_constant(program, node) ) @@ -206,6 +206,8 @@ def is_tensor_arg_node(node: Any) -> bool: if isinstance(node, torch.fx.Node): return is_tensor_node(node) elif isinstance(node, (list, tuple)): + if len(node) == 0: + return False return all(is_tensor_node(n) for n in node) return False @@ -1218,6 +1220,16 @@ def is_in_8bit_range(tensor: torch.Tensor) -> bool: ## +def nchw_dim_to_whcn_dim(nchw_dim: int, ndim: int) -> int: + # Handle negative indices for nchw_dim + if nchw_dim < 0: + nchw_dim += ndim + + assert nchw_dim >= 0 and nchw_dim < ndim + whcn_dim = (ndim - 1) - nchw_dim + return whcn_dim + + def get_tensor_val_str(tensor_val: FakeTensor) -> str: return f"{tensor_val.dtype}: {tensor_val.shape}" @@ -1269,6 +1281,7 @@ def update_program_state_dict( updated_tensor: torch.Tensor, ) -> None: target_name = None + kind = None # Iterate over all the tensors in the graph signature, and find # the one corresponding to the parameter/buffer name for input_ in program.graph_signature.input_specs: @@ -1277,6 +1290,7 @@ def update_program_state_dict( and isinstance(input_.arg, TensorArgument) and input_.arg.name == buffer_name ): + kind = input_.kind target_name = input_.target break @@ -1286,6 +1300,9 @@ def update_program_state_dict( ), f"could not find {buffer_name} in source program signature" assert target_name in program.state_dict, f"could not find {target_name}" + if kind == InputKind.PARAMETER: + updated_tensor = torch.nn.Parameter(updated_tensor, requires_grad=False) + # 
Finally, overwrite the current tensor with updated tensor program.state_dict[target_name] = updated_tensor diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 95da66494e0..d59bd9eff7d 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -8,7 +8,7 @@ from functools import partial -from typing import Any, Dict, final, List +from typing import Any, Callable, Dict, final, List import executorch.backends.vulkan.utils as utils @@ -55,7 +55,9 @@ from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass -from executorch.exir.program._program import _copy_module +from executorch.exir.program._program import _transform + +from torch._export.verifier import Verifier from torch.export._remove_auto_functionalized_pass import ( unsafe_remove_auto_functionalized_pass, @@ -64,28 +66,34 @@ DEFAULT_DEBUG_HANDLE = 65535 +class _any_op(Verifier): + # Set training dialect to skip functional check in base verifier + dialect = "TRAINING" + + def allowed_op_types(self): + return (Callable,) + + # pyre-ignore def apply_passes(program: ExportedProgram, passes) -> ExportedProgram: for p in passes: - if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase): - new_gm = program.graph_module - # This is a workaround to allow the memory planning pass to work without - # having to first apply ToOutVarPass(). See the `greedy()` function in - # `exir.memory_planning`; if this attribute isn't set, assertions in - # `collect_spec_from_nodes()` will fail. - if isinstance(p, MemoryPlanningPass): - new_gm.encounter_to_out_var_failure = True - - new_gm_res = p(new_gm) - assert new_gm_res is not None - new_gm = new_gm_res.graph_module - + if isinstance(p, MemoryPlanningPass) and hasattr(p, "run"): + p.run(program.graph_module) + + elif issubclass(type(p), ExportPass) or issubclass(type(p), PassBase): + # Some passes require the ep to be provided. 
However, since the ep may be + # updated with each pass applied, the ep must be set right before calling + # the pass. _exported_program is the attribute used by XNNPACK and Vulkan + # passes to store the exported program. + if hasattr(p, "_exported_program"): + p._exported_program = program + + program = _transform(program, p, override_verifiers=[_any_op]) # See the application of this function in exir/program/_program.py for more # details on why this step is necessary. if isinstance(p, SpecPropPass): - p.update_placeholder_tensor_specs(program, new_gm) + p.update_placeholder_tensor_specs(program, program.graph_module) - _copy_module(program.graph_module, new_gm) else: program = p(program) @@ -158,16 +166,16 @@ def preprocess( # noqa: C901 program = apply_passes( program, [ - FusePatternsPass(program), - RemoveRedundantOpsTransform(), + FuseBatchNormPass(program), + FusePatternsPass(), + FuseClampPass(), AddmmToLinearTransform(), - FuseQuantizedOpsTransform(program), - FoldQDQPass(program), + RemoveRedundantOpsTransform(), + FuseQuantizedOpsTransform(), + FoldQDQPass(), SqueezeUnsqueezeInputs(), FuseViewCopyTransform(), ViewCopyToSqueezeUnsqueezePass(), - FuseBatchNormPass(program), - FuseClampPass(), ], ) @@ -213,6 +221,11 @@ def preprocess( # noqa: C901 mem_planning_suite = MemoryPlanningAlgorithmSuite( algo_list=[greedy_memory_planning] ) + # This is a workaround to allow the memory planning pass to work without having + # to first apply ToOutVarPass(). See the `greedy()` function in + # `exir.memory_planning`; if this attribute isn't set, assertions in + # `collect_spec_from_nodes()` will fail. 
+ program.graph_module.encounter_to_out_var_failure = True program = apply_passes( program, [ diff --git a/examples/vulkan/export.py b/examples/vulkan/export.py index 4d85d83c862..a1bb42af6c1 100644 --- a/examples/vulkan/export.py +++ b/examples/vulkan/export.py @@ -14,22 +14,18 @@ import backends.vulkan.test.utils as test_utils import torch +import torchvision -from executorch.backends.transforms.convert_dtype_pass import I64toI32 from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner from executorch.devtools import BundledProgram from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite from executorch.devtools.bundled_program.serialize import ( serialize_from_bundled_program_to_flatbuffer, ) -from executorch.exir import ( - EdgeCompileConfig, - ExecutorchBackendConfig, - to_edge_transform_and_lower, -) +from executorch.exir import to_edge_transform_and_lower from executorch.extension.export_util.utils import save_pte_program from executorch.extension.pytree import tree_flatten -from torch.export import export +from torch.export import Dim, export from ..models import MODEL_NAME_TO_MODEL from ..models.model_factory import EagerModelFactory @@ -38,6 +34,67 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) +def is_vision_model(model_name): + if model_name in [ + # These models are also registered in examples/models + "dl3", + "edsr", + "mv2", + "mv3", + "vit", + "ic3", + "ic4", + "resnet18", + "resnet50", + # These models are not registered in examples/models but are available via + # torchvision + "convnext_small", + "densenet161", + "shufflenet_v2_x1_0", + ]: + return True + + return False + + +def get_vision_model_sample_input(): + return (torch.randn(1, 3, 224, 224),) + + +def get_vision_model_dynamic_shapes(): + return ( + { + 2: Dim("height", min=1, max=16) * 16, + 3: Dim("width", min=1, max=16) * 16, + }, + ) + + +def init_model(model_name): + if model_name == "convnext_small": + return 
torchvision.models.convnext_small() + if model_name == "densenet161": + return torchvision.models.densenet161() + if model_name == "shufflenet_v2_x1_0": + return torchvision.models.shufflenet_v2_x1_0() + + return None + + +def get_sample_inputs(model_name): + if is_vision_model(model_name): + return get_vision_model_sample_input() + + return None + + +def get_dynamic_shapes(model_name): + if is_vision_model(model_name): + return get_vision_model_dynamic_shapes() + + return None + + def main() -> None: logger = logging.getLogger("") logger.setLevel(logging.INFO) @@ -68,21 +125,6 @@ def main() -> None: help="whether to export with strict mode. Default is True", ) - parser.add_argument( - "-a", - "--segment_alignment", - required=False, - help="specify segment alignment in hex. Default is 0x1000. Use 0x4000 for iOS", - ) - - parser.add_argument( - "-e", - "--external_constants", - action=argparse.BooleanOptionalAction, - default=False, - help="Save constants in external .ptd file. Default is False", - ) - parser.add_argument( "-d", "--dynamic", @@ -119,31 +161,35 @@ def main() -> None: args = parser.parse_args() - if args.model_name not in MODEL_NAME_TO_MODEL: - raise RuntimeError( - f"Model {args.model_name} is not a valid name. " - f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}." + if args.model_name in MODEL_NAME_TO_MODEL: + model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model( + *MODEL_NAME_TO_MODEL[args.model_name] ) + else: + model = init_model(args.model_name) + example_inputs = get_sample_inputs(args.model_name) + dynamic_shapes = get_dynamic_shapes(args.model_name) if args.dynamic else None - model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model( - *MODEL_NAME_TO_MODEL[args.model_name] - ) + if model is None: + raise RuntimeError( + f"Model {args.model_name} is not a valid name. " + f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}." 
+ ) # Prepare model model.eval() # Setup compile options compile_options = {} - if args.dynamic or dynamic_shapes is not None: + if args.dynamic: compile_options["require_dynamic_shapes"] = True + # Try to manually get the dynamic shapes for the model if not set + if dynamic_shapes is None: + dynamic_shapes = get_dynamic_shapes(args.model_name) + if args.force_fp16: compile_options["force_fp16"] = True - # Configure Edge compilation - edge_compile_config = EdgeCompileConfig( - _skip_dim_order=False, # Proper handling for Vulkan memory format - ) - logging.info(f"Exporting model {args.model_name} with Vulkan delegate") # Export the model using torch.export @@ -157,10 +203,6 @@ def main() -> None: # Transform and lower with Vulkan partitioner edge_program = to_edge_transform_and_lower( program, - compile_config=edge_compile_config, - transform_passes=[ - I64toI32(edge_compile_config._skip_dim_order), - ], partitioner=[VulkanPartitioner(compile_options)], generate_etrecord=args.etrecord, ) @@ -169,13 +211,8 @@ def main() -> None: f"Exported and lowered graph:\n{edge_program.exported_program().graph}" ) - # Configure backend options - backend_config = ExecutorchBackendConfig(external_constants=args.external_constants) - if args.segment_alignment is not None: - backend_config.segment_alignment = int(args.segment_alignment, 16) - # Create executorch program - exec_prog = edge_program.to_executorch(config=backend_config) + exec_prog = edge_program.to_executorch() # Save ETRecord if requested if args.etrecord: