From da51c8141aa4d69aad20987d72d72cd2a9d1b5cb Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 31 Oct 2024 14:20:28 -0700 Subject: [PATCH 1/2] [ET-VK][AOT] Define pass application order Pull Request resolved: https://github.com/pytorch/executorch/pull/6577 ## Changes The goal of this diff is to enforce a specific structure in how graph transform passes are applied during `vulkan_preprocess`. This will help make sure that certain passes are applied at the correct time, and that pre-requisite conditions for passes are fulfilled before they are applied. See the comments in `vulkan_preprocess.py` for more details. ghstack-source-id: 251223076 Differential Revision: [D65234843](https://our.internmc.facebook.com/intern/diff/D65234843/) --- .../vulkan/_passes/insert_prepack_nodes.py | 12 ++- backends/vulkan/vulkan_preprocess.py | 99 +++++++++++++------ 2 files changed, 80 insertions(+), 31 deletions(-) diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index 3f6588d84ad..cafeedbd5da 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -6,6 +6,8 @@ # pyre-strict +from copy import deepcopy + import executorch.backends.vulkan.custom_ops_lib # noqa import torch @@ -69,9 +71,15 @@ def prepack_not_required(node: torch.fx.Node) -> bool: exir_ops.edge.et_vk.prepack.default, (node,), ) - prepack_node.meta["spec"] = node.meta["spec"] + # This pass assumes that the SpecPropPass() has already been applied + assert "spec" in node.meta + # Validate that the original node is marked as a constant. Constant tensors + # do not participate in memory planning. + assert node.meta["spec"].const + prepack_node.meta["val"] = node.meta["val"] + prepack_node.meta["spec"] = deepcopy(node.meta["spec"]) # Set the mem_obj_id to -1 to indicate that this node requires a dedicated - # memory object. This pass must be executed AFTER the memory planning pass. + # memory object. 
prepack_node.meta["spec"].mem_obj_id = -1 node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y) diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 0e116ad2c4c..96eee198f4d 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -17,8 +17,10 @@ from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform -from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform -from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes +from executorch.backends.vulkan._passes import ( + insert_prepack_nodes, + RemoveLocalScalarDenseOpsTransform, +) from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder from executorch.backends.vulkan.serialization.vulkan_graph_serialize import ( @@ -32,6 +34,7 @@ PreprocessResult, ) from executorch.exir.backend.utils import DelegateMappingBuilder +from executorch.exir.pass_base import ExportPass, PassBase from executorch.exir.passes import MemoryPlanningPass, SpecPropPass @@ -46,6 +49,35 @@ DEFAULT_DEBUG_HANDLE = 65535 +# pyre-ignore +def apply_passes(program: ExportedProgram, passes) -> ExportedProgram: + for p in passes: + + if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase): + new_gm = program.graph_module + # This is a workaround to allow the memory planning pass to work without + # having to first apply ToOutVarPass(). See the `greedy()` function in + # `exir.memory_planning`; if this attribute isn't set, assertions in + # `collect_spec_from_nodes()` will fail. + if isinstance(p, MemoryPlanningPass): + new_gm.encounter_to_out_var_failure = True + + new_gm_res = p(new_gm) + assert new_gm_res is not None + new_gm = new_gm_res.graph_module + + # See the application of this function in exir/program/_program.py for more + # details on why this step is necessary. + if isinstance(p, SpecPropPass): + p.update_placeholder_tensor_specs(program, new_gm) + + _copy_module(program.graph_module, new_gm) + else: + program = p(program) + + return program + + @final class VulkanBackend(BackendDetails): @classmethod @@ -57,35 +89,44 @@ def preprocess( # noqa: C901 ) -> PreprocessResult: program = unsafe_remove_auto_functionalized_pass(program) - passes = [ - RemoveCloneOpsTransform(), - AddmmToLinearTransform(), - FuseDequantLinearPass(), - FuseViewCopyTransform(), - FuseBatchNormWithConvPass(program), - FuseClampPass(), - SpecPropPass(), - ConstraintBasedSymShapeEvalPass(), - RemoveLocalScalarDenseOpsTransform(), - MemoryPlanningPass(), - ] - - new_gm = program.graph_module - - for p in passes: - # This is a workaround to allow the memory planning pass to work without - # having to first apply ToOutVarPass(). See the `greedy()` function in - # `exir.memory_planning`; if this attribute isn't set, assertions in - # `collect_spec_from_nodes()` will fail. - if isinstance(p, MemoryPlanningPass): - new_gm.encounter_to_out_var_failure = True - new_gm_res = p(new_gm) - assert new_gm_res is not None - new_gm = new_gm_res.graph_module + # First, apply passes that fuse/remove operators to consolidate the graph + # structure but still preserve an "ATen-compliant" graph structure (i.e. all + # arguments to ATen operators must match the ATen function schema). 
+ program = apply_passes( + program, + [ + RemoveCloneOpsTransform(), + AddmmToLinearTransform(), + FuseDequantLinearPass(), + FuseViewCopyTransform(), + FuseBatchNormWithConvPass(program), + FuseClampPass(), + ], + ) - _copy_module(program.graph_module, new_gm) + # Next annotate tensor nodes with TensorSpec structs which is needed for dynamic + # shapes and memory planning. Until this point, the graph must be ATen compliant + # because SpecPropPass will be calling the underlying ATen operators during its + # execution. + program = apply_passes(program, [SpecPropPass()]) + + # Apply graph transforms which either require `TensorSpec`s to have been created + # or would create an non ATen compliant graph structure. + program = apply_passes( + program, + [ + # Since this pass may replace a scalar argument with a tensor argument, + # this pass may result in a non ATen compliant graph structure. + RemoveLocalScalarDenseOpsTransform(), + insert_prepack_nodes, + ], + ) - program = insert_prepack_nodes(program) + # Finally, apply dynamic shape passes and memory planning pass. These passes + # must be applied only when the graph structure is finalized. + program = apply_passes( + program, [ConstraintBasedSymShapeEvalPass(), MemoryPlanningPass()] + ) graph_builder = VkGraphBuilder( program, DelegateMappingBuilder(generated_identifiers=True) From 3792c79452489c83fd537755b0aad1e5c74d0318 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 31 Oct 2024 14:20:29 -0700 Subject: [PATCH 2/2] [ET-VK][AOT][ez] Introduce vulkan export utils lib Pull Request resolved: https://github.com/pytorch/executorch/pull/6600 ## Changes As title. Introduce a common Python utility library for scripts in the Vulkan backend. ghstack-source-id: 251223077 Differential Revision: [D65291064](https://our.internmc.facebook.com/intern/diff/D65291064/) --- backends/vulkan/_passes/TARGETS | 1 + .../vulkan/_passes/insert_prepack_nodes.py | 21 +---------- .../serialization/vulkan_graph_builder.py | 37 +++++-------------- backends/vulkan/targets.bzl | 13 +++++++ backends/vulkan/utils.py | 30 +++++++++++++++ 5 files changed, 56 insertions(+), 46 deletions(-) create mode 100644 backends/vulkan/utils.py diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index 3f328deb485..cf50f170cf3 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -12,6 +12,7 @@ runtime.python_library( deps = [ "//caffe2:torch", "//executorch/exir:pass_base", + "//executorch/backends/vulkan:utils_lib", ], ) diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index cafeedbd5da..37665a6da8e 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -13,10 +13,10 @@ import torch from executorch.backends.vulkan.op_registry import handles_own_prepacking +from executorch.backends.vulkan.utils import is_param_node from executorch.exir.dialects._ops import ops as exir_ops -from torch._export.utils import is_buffer, is_param from torch.export import ExportedProgram @@ -31,25 +31,8 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: argument into the operator implementation. 
""" - def is_get_attr_node(node: torch.fx.Node) -> bool: - return isinstance(node, torch.fx.Node) and node.op == "get_attr" - - def is_constant(node: torch.fx.Node) -> bool: - return node.name in program.graph_signature.inputs_to_lifted_tensor_constants - - def is_param_node(node: torch.fx.Node) -> bool: - """ - Check if the given node is a parameter within the exported program - """ - return ( - is_get_attr_node(node) - or is_param(program, node) - or is_buffer(program, node) - or is_constant(node) - ) - def prepack_not_required(node: torch.fx.Node) -> bool: - if not is_param_node(node): + if not is_param_node(program, node): return True for user in node.users: diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index f9ae83ddc68..bc77bc40cfb 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -12,6 +12,11 @@ import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema import torch +from executorch.backends.vulkan.utils import ( + is_constant, + is_get_attr_node, + is_param_node, +) from executorch.exir.backend.utils import DelegateMappingBuilder from executorch.exir.tensor import TensorSpec @@ -68,34 +73,12 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType: else: raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") - def is_constant(self, node: Node): - return ( - node.name in self.program.graph_signature.inputs_to_lifted_tensor_constants - ) - - def is_get_attr_node(self, node: Node) -> bool: - """ - Returns true if the given node is a get attr node for a tensor of the model - """ - return isinstance(node, Node) and node.op == "get_attr" - - def is_param_node(self, node: Node) -> bool: - """ - Check if the given node is a parameter within the exported program - """ - return ( - self.is_get_attr_node(node) - or is_param(self.program, node) - or is_buffer(self.program, node) - or self.is_constant(node) - ) - def get_constant(self, node: Node) -> Optional[torch.Tensor]: """ Returns the constant associated with the given node in the exported program. 
Returns None if the node is not a constant within the exported program """ - if self.is_constant(node): + if is_constant(self.program, node): constant_name = ( self.program.graph_signature.inputs_to_lifted_tensor_constants[ node.name @@ -116,9 +99,9 @@ def get_param_tensor(self, node: Node) -> torch.Tensor: tensor = get_param(self.program, node) elif is_buffer(self.program, node): tensor = get_buffer(self.program, node) - elif self.is_constant(node): + elif is_constant(self.program, node): tensor = self.get_constant(node) - elif self.is_get_attr_node(node): + elif is_get_attr_node(node): # This is a hack to support both lifted and unlifted graph try: tensor = getattr(node.graph.owning_module, node.target) @@ -132,7 +115,7 @@ def get_param_tensor(self, node: Node) -> torch.Tensor: def maybe_add_constant_tensor(self, node: Node) -> int: constant_id = -1 - if self.is_param_node(node): + if is_param_node(self.program, node): constant_id = len(self.const_tensors) self.const_tensors.append(self.get_param_tensor(node)) @@ -280,7 +263,7 @@ def process_placeholder_node(self, node: Node) -> None: if len(node.users) == 0: return None ids = self.create_node_value(node) - if not self.is_param_node(node): + if not is_param_node(self.program, node): if isinstance(ids, int): self.input_ids.append(ids) else: diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 12243ec7fab..0d3b17ccccc 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -213,6 +213,19 @@ def define_common_targets(is_fbcode = False): ## AOT targets ## if is_fbcode: + runtime.python_library( + name = "utils_lib", + srcs = [ + "utils.py", + ], + visibility = [ + "//executorch/backends/vulkan/...", + ], + deps = [ + "//caffe2:torch", + ] + ) + runtime.python_library( name = "custom_ops_lib", srcs = [ diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py new file mode 100644 index 00000000000..ae0b8c69406 --- /dev/null +++ b/backends/vulkan/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch._export.utils import is_buffer, is_param + +from torch.export import ExportedProgram + + +def is_get_attr_node(node: torch.fx.Node) -> bool: + return isinstance(node, torch.fx.Node) and node.op == "get_attr" + + +def is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool: + return node.name in program.graph_signature.inputs_to_lifted_tensor_constants + + +def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool: + """ + Check if the given node is a parameter within the exported program + """ + return ( + is_get_attr_node(node) + or is_param(program, node) + or is_buffer(program, node) + or is_constant(program, node) + )
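For reference, a minimal usage sketch of the new helpers in `backends/vulkan/utils.py` — illustrative only and not part of either patch. `SimpleModule` and its example input are hypothetical; the sketch assumes `executorch` and `torch.export` are importable in the environment.

```python
# Illustrative sketch only -- not part of this patch series.
# Shows how the shared is_param_node() helper introduced in utils.py can be
# used to separate lifted parameters/buffers/constants from user inputs.
import torch
from torch.export import export

from executorch.backends.vulkan.utils import is_param_node


class SimpleModule(torch.nn.Module):
    """Hypothetical module used only to produce an ExportedProgram."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


program = export(SimpleModule(), (torch.randn(1, 4),))

# Split placeholder nodes into parameter/buffer/constant nodes vs. user inputs,
# mirroring how vulkan_graph_builder.py and insert_prepack_nodes.py now call the
# shared helper instead of their own local copies.
placeholders = [n for n in program.graph_module.graph.nodes if n.op == "placeholder"]
param_nodes = [n for n in placeholders if is_param_node(program, n)]
input_nodes = [n for n in placeholders if not is_param_node(program, n)]

print(f"{len(param_nodes)} parameter/constant nodes, {len(input_nodes)} user inputs")
```

Centralizing `is_get_attr_node`, `is_constant`, and `is_param_node` in one module lets `insert_prepack_nodes.py` and `vulkan_graph_builder.py` share a single definition instead of maintaining duplicated local copies, which keeps the two call sites consistent as the checks evolve.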