diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index c15fadd102f..845cb5d8631 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -970,11 +970,16 @@ jobs:
         PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh --build

         # Test models serially
-        models="mv2 mv3 edsr resnet18 resnet50 dl3"
+        models="mv2 mv3 edsr resnet18 resnet50 dl3 w2l ic3 ic4"
         for model in $models; do
           python -m examples.vulkan.export --model_name=$model --test
         done

+        # For selected vision models, test with dynamic shapes
+        models="mv2 resnet18 resnet50 ic3 densenet161"
+        for model in $models; do
+          python -m examples.vulkan.export --model_name=$model --test -d
+        done

   test-vulkan-operators-linux:
     name: test-vulkan-operators-linux
diff --git a/backends/vulkan/_passes/fold_qdq.py b/backends/vulkan/_passes/fold_qdq.py
index 3beccc2205c..a6a5e751c05 100644
--- a/backends/vulkan/_passes/fold_qdq.py
+++ b/backends/vulkan/_passes/fold_qdq.py
@@ -17,9 +17,8 @@ class FoldQDQPass(ExportPass):
     valid quant op patterns have already been fused before this pass.
     """

-    def __init__(self, edge_program: torch.export.ExportedProgram):
-        super(FoldQDQPass, self).__init__()
-        self.edge_program = edge_program
+    def __init__(self):
+        super().__init__()

     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
diff --git a/backends/vulkan/_passes/fuse_patterns.py b/backends/vulkan/_passes/fuse_patterns.py
index 6ced1f32a7c..1575dd6a4f6 100644
--- a/backends/vulkan/_passes/fuse_patterns.py
+++ b/backends/vulkan/_passes/fuse_patterns.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Optional
+
 import executorch.backends.vulkan.patterns as vk_patterns

 import torch
@@ -13,13 +15,15 @@


 class FusePatternsPass(ExportPass):
-    def __init__(self, exported_program: ExportedProgram) -> None:
+    def __init__(self) -> None:
         super().__init__()
-        self.program = exported_program
+        self._exported_program: Optional[ExportedProgram] = None

     def call(self, graph_module: torch.fx.GraphModule):
+        assert self._exported_program is not None
+
         total_replaced = vk_patterns.replace_all_fusable_subgraphs(
-            self.program, graph_module
+            self._exported_program, graph_module
         )

         if total_replaced > 0:
diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py
index ca9f7541159..bb8cf5f2e64 100644
--- a/backends/vulkan/_passes/fuse_quantized_ops.py
+++ b/backends/vulkan/_passes/fuse_quantized_ops.py
@@ -211,18 +211,20 @@ def fuse_into_linear_qcnw_node(


 class FuseQuantizedOpsTransform(ExportPass):
-    def __init__(self, exported_program: ExportedProgram) -> None:
+    def __init__(self) -> None:
         super().__init__()
-        self.program = exported_program
+        self._exported_program: Optional[ExportedProgram] = None

     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        assert self._exported_program is not None
+
         for node in graph_module.graph.nodes:
             # Check for linear_qcnw pattern (weight-only quantization)
-            qcnw_details = matches_linear_qcnw_pattern(self.program, node)
+            qcnw_details = matches_linear_qcnw_pattern(self._exported_program, node)
             if qcnw_details is not None:
                 qcnw_method, qcnw_nbits = qcnw_details
                 fuse_into_linear_qcnw_node(
-                    self.program, graph_module, node, qcnw_method, qcnw_nbits
+                    self._exported_program, graph_module, node, qcnw_method, qcnw_nbits
                 )
                 continue
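Note on the pass refactor above: FoldQDQPass, FusePatternsPass, and FuseQuantizedOpsTransform no longer take the ExportedProgram in their constructors. Instead, a runner assigns `_exported_program` immediately before invoking each pass (see `apply_passes` in vulkan_preprocess.py later in this diff), so every pass observes the program as already rewritten by the passes before it. A minimal sketch of the injection pattern, using hypothetical class and function names:

```python
from typing import List, Optional


class ProgramAwarePass:
    """Hypothetical stand-in for the refactored Vulkan passes."""

    def __init__(self) -> None:
        # Populated by the pass runner, not by the constructor.
        self._exported_program: Optional[object] = None

    def call(self, graph_module: object) -> None:
        assert self._exported_program is not None
        # ... rewrite graph_module, consulting self._exported_program ...


def run_passes(program: object, graph_module: object, passes: List[ProgramAwarePass]) -> None:
    for p in passes:
        # Inject the up-to-date program right before the pass executes.
        if hasattr(p, "_exported_program"):
            p._exported_program = program
        p.call(graph_module)
```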
diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py
index db53cc666a8..8ed71aa1dae 100644
--- a/backends/vulkan/_passes/tag_memory_meta_pass.py
+++ b/backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -230,6 +230,10 @@ def get_arg_tensor_source_repset(
     """
     arg_node = op_node.args[arg_i]

+    # For non-tensor arguments, return ANY_STORAGE
+    if not utils.is_tensor_arg_node(arg_node):
+        return utils.ANY_STORAGE
+
     # Special case for cat - use the first tensor in the list as representative
     if isinstance(arg_node, list):
         arg_node = arg_node[0]
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index a92b3b11f6f..63b57a0e79c 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -16,8 +16,6 @@

 import torch

-from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
-
 from executorch.exir.dialects._ops import ops as exir_ops

 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -48,6 +46,9 @@ class OpFeatures:
         # Optional check function used during partitioning to determine if a node's
         # inputs are supported by the operator implementation.
         "are_node_inputs_supported_fn",
+        # Optional function to determine valid representation sets for inputs and outputs
+        # once a node's actual inputs are known.
+        "pick_io_storage_fn",
     ]

     def __init__(
@@ -61,6 +62,7 @@ def __init__(
         supports_resize: bool = False,
         supports_prepacking: bool = False,
         are_node_inputs_supported_fn: Optional[Callable] = allow_node,
+        pick_io_storage_fn: Optional[Callable] = None,
     ):
         self.inputs_storage: utils.TensorRepSetList = utils.TensorRepSetList(
             inputs_storage if inputs_storage is not None else []
@@ -77,15 +79,21 @@ def __init__(
         self.supports_prepacking = supports_prepacking

         self.are_node_inputs_supported_fn = are_node_inputs_supported_fn
+        self.pick_io_storage_fn = pick_io_storage_fn

     def make_op_repsets(
         self,
         op_node: torch.fx.Node,
         texture_limits: utils.ImageExtents = utils.DEFAULT_TEXTURE_LIMITS,
     ) -> utils.OpRepSets:
-        return utils.OpRepSets(
-            self.inputs_storage, self.outputs_storage, op_node, texture_limits
-        )
+        inputs_storage = self.inputs_storage
+        outputs_storage = self.outputs_storage
+        if self.pick_io_storage_fn is not None:
+            i_storage, o_storage = self.pick_io_storage_fn(op_node)
+            inputs_storage = utils.TensorRepSetList(i_storage)
+            outputs_storage = utils.TensorRepSetList(o_storage)
+
+        return utils.OpRepSets(inputs_storage, outputs_storage, op_node, texture_limits)


 #######################
@@ -410,28 +418,16 @@ def register_softmax_op():
 )
 def register_reduce_op():
     def check_reduce_node(node: torch.fx.Node) -> bool:
+        # Only one argument implies that the reduction is over the entire tensor, which
+        # is not supported yet.
+        if len(node.args) == 1:
+            return False
+
         dim_list = node.args[1]
+        # Only 1D and 2D reductions are supported at the moment.
         if isinstance(dim_list, list) and len(dim_list) > 2:
             return False

-        if isinstance(dim_list, list) and len(dim_list) == 2:
-            # Try to get the memory layout for this node
-            try:
-                memory_layout = utils.get_node_memory_layout(node)
-
-                # If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension
-                if (
-                    memory_layout is not None
-                    and memory_layout != VkMemoryLayout.DEFAULT_LAYOUT
-                ):
-                    # For now only default layout is supported for 2D reduction.
-                    # Because we can't determine if the input is NCHW or NHWC here,
-                    # assume the reduction dimension is packed so we cannot support it.
-                    return False
-            except (AssertionError, KeyError, AttributeError):
-                # If we can't get memory layout information, we'll assume the dims aren't packed
-                pass
-
         def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
             for arg in node.args:
                 if isinstance(arg, bool):
@@ -446,10 +442,41 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool:

         return True

+    def pick_io_storage_for_reduce(node: torch.fx.Node):
+        inputs_storage = utils.ANY_TEXTURE
+        outputs_storage = utils.ANY_TEXTURE
+
+        input_tensor = node.args[0]
+        ndim = input_tensor.meta["val"].ndim
+        dim_list = node.args[1]
+        if isinstance(dim_list, list) and len(dim_list) == 2:
+            reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
+            reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
+
+            possible_packed_dims = {0, 1, 2}
+            possible_packed_dims.discard(reduce_dim1_whcn)
+            possible_packed_dims.discard(reduce_dim2_whcn)
+
+            packed_dim = possible_packed_dims.pop()
+            assert packed_dim in [0, 1, 2]
+
+            if packed_dim == 0:
+                inputs_storage = utils.WIDTH_PACKED_TEXTURE
+                outputs_storage = utils.WIDTH_PACKED_TEXTURE
+            elif packed_dim == 1:
+                inputs_storage = utils.HEIGHT_PACKED_TEXTURE
+                outputs_storage = utils.HEIGHT_PACKED_TEXTURE
+            else:
+                inputs_storage = utils.CHANNELS_PACKED_TEXTURE
+                outputs_storage = utils.CHANNELS_PACKED_TEXTURE
+
+        return inputs_storage, outputs_storage
+
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_resize=True,
         are_node_inputs_supported_fn=check_reduce_node,
+        pick_io_storage_fn=pick_io_storage_for_reduce,
     )
@@ -474,6 +501,23 @@ def register_2d_pool_op():
     ]
 )
 def register_convolution_op():
+    def check_conv_node(node: torch.fx.Node) -> bool:
+        x = node.args[0]
+        x_shape = x.meta["val"].size()
+        # 4-D input implies 2D convolution
+        if len(x_shape) == 4:
+            batches = x.meta["val"].size()[0]
+            if batches != 1:
+                return False
+        # 3-D input implies 1D convolution
+        if len(x_shape) == 3:
+            transpose = node.args[6]
+            # Transposed 1D convolution is not supported yet
+            if transpose:
+                return False
+
+        return True
+
     return OpFeatures(
         inputs_storage=[
             utils.CHANNELS_PACKED_TEXTURE,  # input
@@ -490,6 +534,7 @@
         ],
         supports_resize=True,
         supports_prepacking=True,
+        are_node_inputs_supported_fn=check_conv_node,
     )
@@ -716,6 +761,7 @@ def register_ported_ops_with_prepacking():
     return OpFeatures(
         inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
         supports_prepacking=True,
+        supports_resize=True,
     )
@@ -746,6 +792,7 @@ def register_ported_ops_with_prepacking_all_dims():
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_prepacking=True,
+        supports_resize=True,
     )
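Note on `pick_io_storage_for_reduce` above: the layout decision is deferred until the node's actual dims are known. The two reduction dims are mapped from NCHW order onto WHCN texture axes (0=W, 1=H, 2=C), and the one axis the reduction does not touch becomes the packed dim. A worked example of that selection, reusing the `nchw_dim_to_whcn_dim` logic added to backends/vulkan/utils.py later in this diff (standalone sketch, not the registry code itself):

```python
def nchw_dim_to_whcn_dim(nchw_dim: int, ndim: int) -> int:
    # Same logic as the helper added to backends/vulkan/utils.py in this diff.
    if nchw_dim < 0:
        nchw_dim += ndim
    return (ndim - 1) - nchw_dim


def pick_packed_dim(dim_list, ndim):
    # The packed axis must be one the reduction does not touch.
    candidates = {0, 1, 2}  # 0=W, 1=H, 2=C
    for d in dim_list:
        candidates.discard(nchw_dim_to_whcn_dim(d, ndim))
    return candidates.pop()


# Reducing H and W (NCHW dims 2 and 3) of a 4-D tensor leaves C as the packed axis (2).
assert pick_packed_dim([2, 3], ndim=4) == 2
```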
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index e5b2d0f7864..0bdc16616ef 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -36,7 +36,7 @@
     Partitioner,
     PartitionResult,
 )
-from executorch.exir.backend.utils import tag_constant_data
+from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
 from executorch.exir.dialects._ops import ops as exir_ops

 from torch.export.exported_program import ExportedProgram
@@ -254,9 +254,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:  # noqa: C901
                 self.log_skip(node, "permute node of non compatible linear node")
                 return False

-        is_in_local_scalar_dense_chain, dst_node_is_compatible = (
-            self.is_in_local_scalar_dense_chain(node)
-        )
+        (
+            is_in_local_scalar_dense_chain,
+            dst_node_is_compatible,
+        ) = self.is_in_local_scalar_dense_chain(node)
         if is_in_local_scalar_dense_chain and dst_node_is_compatible:
             return True
         elif is_in_local_scalar_dense_chain:
@@ -419,6 +420,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
             logger.info(f"Found {pl} Vulkan subgraphs to be partitioned.")

         tag_constant_data(exported_program)
+        tag_mutated_buffer(exported_program)

         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py
index 882d0d41e6d..374e29c634d 100644
--- a/backends/vulkan/patterns/quantized_linear.py
+++ b/backends/vulkan/patterns/quantized_linear.py
@@ -92,9 +92,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None:
             return

         # Identify input node
-        self.fp_input_node, self.quantize_input_node, dq_node = (
-            utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0])
-        )
+        (
+            self.fp_input_node,
+            self.quantize_input_node,
+            dq_node,
+        ) = utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0])

         assert self.fp_input_node is not None
         self.all_nodes.append(self.fp_input_node)
@@ -386,7 +388,7 @@ def make_linear_dq8ca_q4gsw_op(
     weight_sums_node = create_constant_placeholder(
         exp_program=ep,
         graph=graph_module.graph,
-        kind=InputKind.CONSTANT_TENSOR,
+        kind=InputKind.PARAMETER,
         name=sums_name,
         data=sum_per_quant_group,
     )
@@ -429,7 +431,7 @@ def make_linear_q8ta_q8csw_custom_op(
     weight_sums_node = create_constant_placeholder(
         exp_program=ep,
         graph=graph_module.graph,
-        kind=InputKind.CONSTANT_TENSOR,
+        kind=InputKind.PARAMETER,
         name=sums_name,
         data=sum_per_output_channel,
     )
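Note on the `InputKind.CONSTANT_TENSOR` → `InputKind.PARAMETER` switch above: it pairs with the `update_program_state_dict` change in backends/vulkan/utils.py later in this diff. Values registered as parameters must live in the program's state dict as `nn.Parameter` objects, not plain tensors. A small illustration of that invariant (hypothetical stand-in tensor, not the actual weight-sums data):

```python
import torch

# A stand-in for the precomputed weight-sums tensor created by the pattern.
sums = torch.zeros(128)

# When the placeholder kind is PARAMETER, the state-dict entry is wrapped like so.
as_param = torch.nn.Parameter(sums, requires_grad=False)
assert isinstance(as_param, torch.nn.Parameter)
assert as_param.requires_grad is False
```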
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl
index 0f5dbc41273..88746c5594e 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl
@@ -60,7 +60,7 @@ void main() {
     int num_steps = ((-ipos.y) + dilation.y - 1) / dilation.y;
     start.y = ipos.y + num_steps * dilation.y;
   }
-  const ivec2 end = min(ipos + overlay_region.xy, ivec2(in_sizes.xy));
+  const ivec2 end = min(ipos + overlay_region.xy, in_sizes.xy);
   // Compute the start of the kernel based on how far we are skipping ahead when
   // reading the input. Note that these are "canonical" indices.
   ivec2 kstart = (start - ipos) / dilation;
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
index 02fbef29b75..9089f87d658 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
@@ -54,7 +54,7 @@ void main() {
   // Compute the start and end of the input indices to load. Padding is assumed
   // to be constant 0 padding, so reads from the padding region are skipped.
   const ivec2 start = ipos;
-  const ivec2 end = ipos + overlay_region.xy;
+  const ivec2 end = min(ipos + overlay_region.xy, in_sizes.xy);

   VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
   int kx = 0;
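Note on the two shader fixes above: both clamp the end of the input read window to the input extents. Without the clamp, `texelFetch` can be asked for coordinates past the edge of the image, and out-of-bounds `texelFetch` results are undefined in GLSL. The arithmetic restated as a Python sketch (hypothetical helper, per-axis min):

```python
def read_window(ipos, overlay, in_sizes):
    # start stays at the output-mapped input position; end is clamped per axis.
    start = tuple(ipos)
    end = tuple(min(p + o, s) for p, o, s in zip(ipos, overlay, in_sizes))
    return start, end


# A 3x3 overlay anchored at (6, 6) of an 8x8 input stops at (8, 8), not (9, 9).
assert read_window((6, 6), (3, 3), (8, 8)) == ((6, 6), (8, 8))
```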
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
index 19250419baf..7448b042cad 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
@@ -97,6 +97,10 @@ void main() {
   for (int y = start.y, i = 0; i < TILE_SIZE + BATCH_SIZE_Y - 1; y += dilation.y, i++) {
     for (int x = start.x, j = 0; j < TILE_SIZE + BATCH_SIZE_X - 1; x += dilation.x, j++) {
       in_texels[j] = texelFetch(t_in, ivec3(x, y, pos.z), 0);
+      // Set to zero if reading out of bounds
+      if (any(greaterThanEqual(ivec2(x, y), in_sizes.xy))) {
+        in_texels[j] = VEC4_T(0);
+      }
     }

     // from 2nd iteration onwards accumulate dot product in 2nd sum
diff --git a/backends/vulkan/runtime/graph/ops/glsl/full.yaml b/backends/vulkan/runtime/graph/ops/glsl/full.yaml
index eff78a7938d..1a5b0cb235e 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/full.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/full.yaml
@@ -14,5 +14,6 @@ full:
       DTYPE:
         - VALUE: half
         - VALUE: float
+        - VALUE: int32
   shader_variants:
     - NAME: full
diff --git a/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl
index d35492bc367..86a2229c416 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl
@@ -42,7 +42,8 @@ layout(constant_id = 5) const int group_dim = 1;
 // work group will write into its assigned element in the shared array.
 #define MAX_NTHREADS 16

-shared vec4 shared_vecs[MAX_NTHREADS];
+shared vec4 shared_max[MAX_NTHREADS];
+shared vec4 shared_sum[MAX_NTHREADS];

 #include "indexing_utils.h"

@@ -102,13 +103,13 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
        i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
     max_elements = max(max_elements, load_texel(tin, scan_pos));
   }
-  shared_vecs[smi] = max_elements;
+  shared_max[smi] = max_elements;
   barrier();
   // Iterate over the partial maximums to obtain the overall maximum
   group_i = tid.y * NWORKERS;
-  max_elements = shared_vecs[group_i++];
+  max_elements = shared_max[group_i++];
   for (int i = 1; i < NWORKERS; ++i, group_i++) {
-    max_elements = max(max_elements, shared_vecs[group_i]);
+    max_elements = max(max_elements, shared_max[group_i]);
   }

   scan_pos[reduce_dim] = tid.x;
@@ -118,13 +119,13 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
        i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
     denominators += exp(load_texel(tin, scan_pos) - max_elements);
   }
-  shared_vecs[smi] = denominators;
+  shared_sum[smi] = denominators;
   barrier();
   // Iterate over the partial sums to obtain the overall sum
   group_i = tid.y * NWORKERS;
-  denominators = shared_vecs[group_i++];
+  denominators = shared_sum[group_i++];
   for (int i = 1; i < NWORKERS; ++i, group_i++) {
-    denominators += shared_vecs[group_i];
+    denominators += shared_sum[group_i];
   }

   // Determine if there are any padding elements in the final texel of the
@@ -184,13 +185,13 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
       max_elements.x = max(intex[i], max_elements.x);
     }
   }
-  shared_vecs[smi] = max_elements;
+  shared_max[smi] = max_elements;
   barrier();
   // Iterate over the partial maximums to obtain the overall maximum
   group_i = tid.y * NWORKERS;
-  max_elements = shared_vecs[group_i++];
+  max_elements = shared_max[group_i++];
   for (int i = 1; i < NWORKERS; ++i, group_i++) {
-    max_elements = max(max_elements, shared_vecs[group_i]);
+    max_elements = max(max_elements, shared_max[group_i]);
   }
   // Each element of the texel is itself a partial maximum; iterate over the
   // texel to find the actual maximum
@@ -214,13 +215,13 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
       denominators.x += exp(intex[i] - max_element);
     }
   }
-  shared_vecs[smi] = denominators;
+  shared_sum[smi] = denominators;
   barrier();
   // Iterate over the partial sums to obtain the overall sum
   group_i = tid.y * NWORKERS;
-  denominators = shared_vecs[group_i++];
+  denominators = shared_sum[group_i++];
   for (int i = 1; i < NWORKERS; ++i, group_i++) {
-    denominators += shared_vecs[group_i];
+    denominators += shared_sum[group_i];
   }
   // Reduce over the accumulated texel to find the overall sum
   float denominator = 0;
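Note on the softmax change above: one reused `shared_vecs` array is split into `shared_max` and `shared_sum`, one per reduction stage. The underlying algorithm is the usual two-stage parallel softmax: each worker publishes a partial max of its strided slice, the partials are combined, then each worker publishes a partial sum of `exp(x - max)`. A scalar Python sketch of that schedule (hypothetical helper, NWORKERS-strided assignment):

```python
import math


def softmax_two_stage(x, nworkers=4):
    # Stage 1: each worker reduces its strided slice to a partial max.
    partial_max = [max(x[w::nworkers], default=-math.inf) for w in range(nworkers)]
    global_max = max(partial_max)  # combine partials (the shared_max pass)
    # Stage 2: each worker accumulates a partial sum of exp(x - max).
    partial_sum = [
        sum(math.exp(v - global_max) for v in x[w::nworkers]) for w in range(nworkers)
    ]
    denom = sum(partial_sum)  # combine partials (the shared_sum pass)
    return [math.exp(v - global_max) / denom for v in x]


probs = softmax_two_stage([1.0, 2.0, 3.0, 4.0, 5.0])
assert abs(sum(probs) - 1.0) < 1e-6
```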
diff --git a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp
index 757afd06849..a6dd8f07f53 100644
--- a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp
@@ -19,6 +19,18 @@

 namespace vkcompute {

+void resize_batch_norm_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef self = args.at(1).refs.at(0);
+
+  // For batch norm, output dimensions are the same as input dimensions
+  std::vector<int64_t> new_out_sizes = graph->sizes_of(self);
+  graph->virtual_resize(out, new_out_sizes);
+}
+
 ValueRef check_and_prepack_arg(
     ComputeGraph& graph,
     ValueRef arg_ref,
@@ -101,7 +113,7 @@ void add_native_batch_norm_node(
       // Resize Args
       {},
       // Resizing Logic
-      nullptr));
+      resize_batch_norm_node));
 }

 void native_batch_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) {
diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp
index 9ac4c963bc3..329620e80e6 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp
@@ -109,11 +109,15 @@ void add_permute_node(
   {
     IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims);
     const int32_t permute_ndim =
-        utils::safe_downcast<int32_t>(permute_dims_ptr->size());
+        utils::safe_downcast<int32_t>(permute_dims_ptr->size());
     for (int32_t nchw_i = permute_ndim - 1, whcn_i = 0; nchw_i >= 0;
          nchw_i--, whcn_i++) {
-      const int32_t permute_dim_nchw = permute_dims_ptr->at(nchw_i);
+      int32_t permute_dim_nchw =
+          utils::safe_downcast<int32_t>(permute_dims_ptr->at(nchw_i));
+      if (permute_dim_nchw < 0) {
+        permute_dim_nchw += permute_ndim;
+      }
       const int32_t permute_dim_whcn = permute_ndim - 1 - permute_dim_nchw;

       whcn_permute_dims[whcn_i] = permute_dim_whcn;
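Note on the Permute fix above: negative NCHW dims are now wrapped by the rank before being flipped into WHCN order. A sketch of just that normalize-and-flip step (hypothetical helper; the real code also fills `whcn_permute_dims` in reversed iteration order):

```python
def normalize_and_flip(dims, ndim):
    out = []
    for d in dims:
        if d < 0:
            d += ndim  # e.g. -1 -> ndim - 1
        out.append(ndim - 1 - d)  # NCHW index -> WHCN index
    return out


# permute(0, -2, -1, -3) must resolve to the same plan as permute(0, 2, 3, 1).
assert normalize_and_flip([0, -2, -1, -3], 4) == normalize_and_flip([0, 2, 3, 1], 4)
```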
diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp
index 250fcdd5490..879f59667d6 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp
@@ -137,7 +137,7 @@ void max_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {

 struct DivisorParams final {
   int32_t divisor_override;
-  bool count_include_pad;
+  int32_t count_include_pad;
 };

 DivisorParams create_divisor_params(
@@ -148,7 +148,7 @@ DivisorParams create_divisor_params(
       graph.val_is_int(divisor_override)
           ? static_cast<int32_t>(graph.get_int(divisor_override))
           : 0,
-      graph.get_bool(count_include_pad)};
+      int32_t(graph.get_bool(count_include_pad))};
 }

 void add_avg_pool2d_node(
diff --git a/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp
index 13801b45cc7..e2b73b2f3f2 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp
@@ -32,8 +32,13 @@ void add_squeeze_copy_dims_node(
   // 2. Squeeze outter most dim
   // For these cases, just pass input to output via clone.
   for (int i = 0; i < dims.size(); ++i) {
-    if (dims.at(i) != 0 && in_sizes.at(dims.at(i)) == 1) {
-      squeeze_dims.push_back(dims.at(i));
+    // adjust negative dims
+    int64_t dim_val = dims.at(i);
+    if (dim_val < 0) {
+      dim_val += in_dim;
+    }
+    if (dims.at(i) != 0 && in_sizes.at(dim_val) == 1) {
+      squeeze_dims.push_back(dim_val);
     }
   }
   if (squeeze_dims.size() == 0) {
diff --git a/backends/vulkan/test/TARGETS b/backends/vulkan/test/TARGETS
index 53fad86f90c..ee296a4f68f 100644
--- a/backends/vulkan/test/TARGETS
+++ b/backends/vulkan/test/TARGETS
@@ -34,7 +34,6 @@ python_unittest(
     deps = [
         "//caffe2:torch",
        "//executorch/backends/vulkan/_passes:vulkan_passes",
-        "//executorch/backends/vulkan/quantizer:vulkan_quantizer",
        "//executorch/backends/vulkan:vulkan_preprocess",
        "//pytorch/ao:torchao",  # @manual
    ]
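Note on the Squeeze fix above: it follows the same convention as the Permute fix — negative dims are wrapped by the input rank (`in_dim` in the C++) before the unit-size check, while the original `dims.at(i) != 0` guard is kept. Equivalent Python (hypothetical helper):

```python
def squeezable_dims(dims, in_sizes):
    ndim = len(in_sizes)
    out = []
    for d in dims:
        dim_val = d + ndim if d < 0 else d  # wrap negative dims
        # Keep the existing guard on the *original* dim; squeeze only unit dims.
        if d != 0 and in_sizes[dim_val] == 1:
            out.append(dim_val)
    return out


# squeeze(dim=-1) on a (2, 3, 1) tensor resolves to dim 2.
assert squeezable_dims([-1], [2, 3, 1]) == [2]
```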
"linear_qcs4w.default"), 1) - self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) - def test_fuse_rotary_emb(self): """Test conversion of rotary embedding pattern to et_vk.apply_rotary_emb custom op.""" @@ -238,7 +171,8 @@ def _reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor): # Apply the rotary embedding pass ep = edge_manager._edge_programs["forward"] - rotary_pass = FusePatternsPass(ep) + rotary_pass = FusePatternsPass() + rotary_pass._exported_program = ep result = rotary_pass.call(ep.graph_module) # Verify that the pass was successful diff --git a/backends/vulkan/test/utils.py b/backends/vulkan/test/utils.py index bfe4e9fceee..a887c53473a 100644 --- a/backends/vulkan/test/utils.py +++ b/backends/vulkan/test/utils.py @@ -90,7 +90,9 @@ def export_model_to_vulkan( qmode=QuantizationMode.NONE, ): compile_options = {} - exported_graph = get_exported_graph(model, sample_inputs, qmode=qmode) + exported_graph = get_exported_graph( + model, sample_inputs, dynamic_shapes=dynamic_shapes, qmode=qmode + ) program = export( exported_graph, sample_inputs, diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 972a4f26c1b..09c57f649ae 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -128,7 +128,7 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool: is_get_attr_node(node) or is_param(program, node) or is_buffer(program, node) - or is_constant(program, node) + or is_lifted_tensor_constant(program, node) ) @@ -206,6 +206,8 @@ def is_tensor_arg_node(node: Any) -> bool: if isinstance(node, torch.fx.Node): return is_tensor_node(node) elif isinstance(node, (list, tuple)): + if len(node) == 0: + return False return all(is_tensor_node(n) for n in node) return False @@ -1228,6 +1230,16 @@ def is_in_8bit_range(tensor: torch.Tensor) -> bool: ## +def nchw_dim_to_whcn_dim(nchw_dim: int, ndim: int) -> int: + # Handle negative indices for nchw_dim + if nchw_dim < 0: + nchw_dim += ndim + + assert nchw_dim >= 0 and nchw_dim < ndim + whcn_dim = (ndim - 1) - nchw_dim + return whcn_dim + + def get_tensor_val_str(tensor_val: FakeTensor) -> str: return f"{tensor_val.dtype}: {tensor_val.shape}" @@ -1279,6 +1291,7 @@ def update_program_state_dict( updated_tensor: torch.Tensor, ) -> None: target_name = None + kind = None # Iterate over all the tensors in the graph signature, and find # the one corresponding to the parameter/buffer name for input_ in program.graph_signature.input_specs: @@ -1287,6 +1300,7 @@ def update_program_state_dict( and isinstance(input_.arg, TensorArgument) and input_.arg.name == buffer_name ): + kind = input_.kind target_name = input_.target break @@ -1296,6 +1310,9 @@ def update_program_state_dict( ), f"could not find {buffer_name} in source program signature" assert target_name in program.state_dict, f"could not find {target_name}" + if kind == InputKind.PARAMETER: + updated_tensor = torch.nn.Parameter(updated_tensor, requires_grad=False) + # Finally, overwrite the current tensor with updated tensor program.state_dict[target_name] = updated_tensor diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 2f91d97ff58..876f7fa8900 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -8,7 +8,7 @@ from functools import partial -from typing import Any, Dict, final, List +from typing import Any, Callable, Dict, final, List import executorch.backends.vulkan.utils as utils @@ -56,7 +56,9 @@ from 
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 2f91d97ff58..876f7fa8900 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -8,7 +8,7 @@

 from functools import partial

-from typing import Any, Dict, final, List
+from typing import Any, Callable, Dict, final, List

 import executorch.backends.vulkan.utils as utils

@@ -56,7 +56,9 @@

 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

-from executorch.exir.program._program import _copy_module
+from executorch.exir.program._program import _transform
+
+from torch._export.verifier import Verifier

 from torch.export._remove_auto_functionalized_pass import (
     unsafe_remove_auto_functionalized_pass,
@@ -65,28 +67,34 @@

 DEFAULT_DEBUG_HANDLE = 65535


+class _any_op(Verifier):
+    # Set training dialect to skip functional check in base verifier
+    dialect = "TRAINING"
+
+    def allowed_op_types(self):
+        return (Callable,)
+
+
 # pyre-ignore
 def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
     for p in passes:
-        if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase):
-            new_gm = program.graph_module
-            # This is a workaround to allow the memory planning pass to work without
-            # having to first apply ToOutVarPass(). See the `greedy()` function in
-            # `exir.memory_planning`; if this attribute isn't set, assertions in
-            # `collect_spec_from_nodes()` will fail.
-            if isinstance(p, MemoryPlanningPass):
-                new_gm.encounter_to_out_var_failure = True
-
-            new_gm_res = p(new_gm)
-            assert new_gm_res is not None
-            new_gm = new_gm_res.graph_module
-
+        if isinstance(p, MemoryPlanningPass) and hasattr(p, "run"):
+            p.run(program.graph_module)
+
+        elif issubclass(type(p), ExportPass) or issubclass(type(p), PassBase):
+            # Some passes require the ep to be provided. However, since the ep may be
+            # updated with each pass applied, the ep must be set right before calling
+            # the pass. _exported_program is the attribute used by XNNPACK and Vulkan
+            # passes to store the exported program.
+            if hasattr(p, "_exported_program"):
+                p._exported_program = program
+
+            program = _transform(program, p, override_verifiers=[_any_op])
             # See the application of this function in exir/program/_program.py for more
             # details on why this step is necessary.
             if isinstance(p, SpecPropPass):
-                p.update_placeholder_tensor_specs(program, new_gm)
+                p.update_placeholder_tensor_specs(program, program.graph_module)

-            _copy_module(program.graph_module, new_gm)
         else:
             program = p(program)
@@ -159,17 +167,17 @@ def preprocess(  # noqa: C901
         program = apply_passes(
             program,
             [
-                FusePatternsPass(program),
-                RemoveRedundantOpsTransform(),
+                FuseBatchNormPass(program),
+                FusePatternsPass(),
+                FuseClampPass(),
                 AddmmToLinearTransform(),
-                FuseQuantizedOpsTransform(program),
+                RemoveRedundantOpsTransform(),
+                FuseQuantizedOpsTransform(),
                 ReplaceQDQPass(),
-                FoldQDQPass(program),
+                FoldQDQPass(),
                 SqueezeUnsqueezeInputs(),
                 FuseViewCopyTransform(),
                 ViewCopyToSqueezeUnsqueezePass(),
-                FuseBatchNormPass(program),
-                FuseClampPass(),
             ],
         )
@@ -215,6 +223,11 @@
         mem_planning_suite = MemoryPlanningAlgorithmSuite(
             algo_list=[greedy_memory_planning]
         )
+        # This is a workaround to allow the memory planning pass to work without having
+        # to first apply ToOutVarPass(). See the `greedy()` function in
+        # `exir.memory_planning`; if this attribute isn't set, assertions in
+        # `collect_spec_from_nodes()` will fail.
+        program.graph_module.encounter_to_out_var_failure = True
         program = apply_passes(
             program,
             [
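Context for the `_any_op` verifier above: `_transform()` re-verifies the program after every pass, and the default Edge verifier would reject the intermediate graphs these backend passes produce. Subclassing the verifier with `dialect = "TRAINING"` skips the functional-ops check in the base class, and returning `(Callable,)` from `allowed_op_types()` accepts any call target, effectively disabling the op-set restriction for this pipeline. A standalone sketch of the shape of such a verifier (illustrative only; the real base class is `torch._export.verifier.Verifier`):

```python
from typing import Callable


class AnyOpVerifier:
    # "TRAINING" opts out of the functional-graph check in the base verifier.
    dialect = "TRAINING"

    def allowed_op_types(self):
        # Permit every callable op target.
        return (Callable,)


assert isinstance(print, AnyOpVerifier().allowed_op_types())
```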
diff --git a/examples/vulkan/export.py b/examples/vulkan/export.py
index c90b501df6f..dace37e5473 100644
--- a/examples/vulkan/export.py
+++ b/examples/vulkan/export.py
@@ -14,22 +14,18 @@
 import backends.vulkan.test.utils as test_utils

 import torch
+import torchvision

-from executorch.backends.transforms.convert_dtype_pass import I64toI32
 from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
 from executorch.devtools import BundledProgram
 from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
 from executorch.devtools.bundled_program.serialize import (
     serialize_from_bundled_program_to_flatbuffer,
 )
-from executorch.exir import (
-    EdgeCompileConfig,
-    ExecutorchBackendConfig,
-    to_edge_transform_and_lower,
-)
+from executorch.exir import to_edge_transform_and_lower
 from executorch.extension.export_util.utils import save_pte_program
 from executorch.extension.pytree import tree_flatten
-from torch.export import export
+from torch.export import Dim, export

 from ..models import MODEL_NAME_TO_MODEL
 from ..models.model_factory import EagerModelFactory
@@ -38,6 +34,67 @@
 logging.basicConfig(level=logging.INFO, format=FORMAT)


+def is_vision_model(model_name):
+    if model_name in [
+        # These models are also registered in examples/models
+        "dl3",
+        "edsr",
+        "mv2",
+        "mv3",
+        "vit",
+        "ic3",
+        "ic4",
+        "resnet18",
+        "resnet50",
+        # These models are not registered in examples/models but are available via
+        # torchvision
+        "convnext_small",
+        "densenet161",
+        "shufflenet_v2_x1_0",
+    ]:
+        return True
+
+    return False
+
+
+def get_vision_model_sample_input():
+    return (torch.randn(1, 3, 224, 224),)
+
+
+def get_vision_model_dynamic_shapes():
+    return (
+        {
+            2: Dim("height", min=1, max=16) * 16,
+            3: Dim("width", min=1, max=16) * 16,
+        },
+    )
+
+
+def init_model(model_name):
+    if model_name == "convnext_small":
+        return torchvision.models.convnext_small()
+    if model_name == "densenet161":
+        return torchvision.models.densenet161()
+    if model_name == "shufflenet_v2_x1_0":
+        return torchvision.models.shufflenet_v2_x1_0()
+
+    return None
+
+
+def get_sample_inputs(model_name):
+    if is_vision_model(model_name):
+        return get_vision_model_sample_input()
+
+    return None
+
+
+def get_dynamic_shapes(model_name):
+    if is_vision_model(model_name):
+        return get_vision_model_dynamic_shapes()
+
+    return None
+
+
 def main() -> None:
     logger = logging.getLogger("")
     logger.setLevel(logging.INFO)
@@ -68,21 +125,6 @@ def main() -> None:
         help="whether to export with strict mode. Default is True",
     )

-    parser.add_argument(
-        "-a",
-        "--segment_alignment",
-        required=False,
-        help="specify segment alignment in hex. Default is 0x1000. Use 0x4000 for iOS",
-    )
-
-    parser.add_argument(
-        "-e",
-        "--external_constants",
-        action=argparse.BooleanOptionalAction,
-        default=False,
-        help="Save constants in external .ptd file. Default is False",
-    )
-
     parser.add_argument(
         "-d",
         "--dynamic",
@@ -119,31 +161,35 @@ def main() -> None:

     args = parser.parse_args()

-    if args.model_name not in MODEL_NAME_TO_MODEL:
-        raise RuntimeError(
-            f"Model {args.model_name} is not a valid name. "
-            f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
+    if args.model_name in MODEL_NAME_TO_MODEL:
+        model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model(
+            *MODEL_NAME_TO_MODEL[args.model_name]
         )
+    else:
+        model = init_model(args.model_name)
+        example_inputs = get_sample_inputs(args.model_name)
+        dynamic_shapes = get_dynamic_shapes(args.model_name) if args.dynamic else None

-    model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model(
-        *MODEL_NAME_TO_MODEL[args.model_name]
-    )
+    if model is None:
+        raise RuntimeError(
+            f"Model {args.model_name} is not a valid name. "
+            f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
+        )

     # Prepare model
     model.eval()

     # Setup compile options
     compile_options = {}
-    if args.dynamic or dynamic_shapes is not None:
+    if args.dynamic:
         compile_options["require_dynamic_shapes"] = True
+        # Try to manually get the dynamic shapes for the model if not set
+        if dynamic_shapes is None:
+            dynamic_shapes = get_dynamic_shapes(args.model_name)
+
     if args.force_fp16:
         compile_options["force_fp16"] = True

-    # Configure Edge compilation
-    edge_compile_config = EdgeCompileConfig(
-        _skip_dim_order=False,  # Proper handling for Vulkan memory format
-    )
-
     logging.info(f"Exporting model {args.model_name} with Vulkan delegate")

     # Export the model using torch.export
@@ -157,10 +203,6 @@
     # Transform and lower with Vulkan partitioner
     edge_program = to_edge_transform_and_lower(
         program,
-        compile_config=edge_compile_config,
-        transform_passes=[
-            I64toI32(edge_compile_config._skip_dim_order),
-        ],
         partitioner=[VulkanPartitioner(compile_options)],
         generate_etrecord=args.etrecord,
     )
@@ -169,13 +211,8 @@
         f"Exported and lowered graph:\n{edge_program.exported_program().graph}"
     )

-    # Configure backend options
-    backend_config = ExecutorchBackendConfig(external_constants=args.external_constants)
-    if args.segment_alignment is not None:
-        backend_config.segment_alignment = int(args.segment_alignment, 16)
-
     # Create executorch program
-    exec_prog = edge_program.to_executorch(config=backend_config)
+    exec_prog = edge_program.to_executorch()

     # Save ETRecord if requested
     if args.etrecord:
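A closing note on the dynamic-shape spec used by `get_vision_model_dynamic_shapes` above: `Dim("height", min=1, max=16) * 16` is a torch.export derived dim, constraining the runtime height to a multiple of 16 in [16, 256]; multiples of 16 keep the spatial extents compatible with the stride-16 downsampling typical of these vision backbones. A minimal standalone usage under those assumptions (hypothetical module; mirrors the spec in the diff):

```python
import torch
from torch.export import Dim, export


class Scale(torch.nn.Module):
    def forward(self, x):
        return x * 2.0


# Height and width are multiples of 16, ranging from 16 to 256.
h = Dim("height", min=1, max=16) * 16
w = Dim("width", min=1, max=16) * 16

ep = export(
    Scale(),
    (torch.randn(1, 3, 224, 224),),  # 224 = 14 * 16, inside the allowed range
    dynamic_shapes=({2: h, 3: w},),
)
```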