1717from executorch .backends .transforms .fuse_view_copy import FuseViewCopyTransform
1818from executorch .backends .transforms .remove_clone_ops import RemoveCloneOpsTransform
1919
20- from executorch .backends .vulkan ._passes import RemoveLocalScalarDenseOpsTransform
21- from executorch .backends .vulkan ._passes .insert_prepack_nodes import insert_prepack_nodes
20+ from executorch .backends .vulkan ._passes import (
21+ insert_prepack_nodes ,
22+ RemoveLocalScalarDenseOpsTransform ,
23+ )
2224
2325from executorch .backends .vulkan .serialization .vulkan_graph_builder import VkGraphBuilder
2426from executorch .backends .vulkan .serialization .vulkan_graph_serialize import (
3234 PreprocessResult ,
3335)
3436from executorch .exir .backend .utils import DelegateMappingBuilder
37+ from executorch .exir .pass_base import ExportPass , PassBase
3538
3639from executorch .exir .passes import MemoryPlanningPass , SpecPropPass
3740
4649DEFAULT_DEBUG_HANDLE = 65535
4750
4851
52+ # pyre-ignore
53+ def apply_passes (program : ExportedProgram , passes ) -> ExportedProgram :
54+ for p in passes :
55+
56+ if issubclass (type (p ), ExportPass ) or issubclass (type (p ), PassBase ):
57+ new_gm = program .graph_module
58+ # This is a workaround to allow the memory planning pass to work without
59+ # having to first apply ToOutVarPass(). See the `greedy()` function in
60+ # `exir.memory_planning`; if this attribute isn't set, assertions in
61+ # `collect_spec_from_nodes()` will fail.
62+ if isinstance (p , MemoryPlanningPass ):
63+ new_gm .encounter_to_out_var_failure = True
64+
65+ new_gm_res = p (new_gm )
66+ assert new_gm_res is not None
67+ new_gm = new_gm_res .graph_module
68+
69+ # See the application of this function in exir/program/_program.py for more
70+ # details on why this step is necessary.
71+ if isinstance (p , SpecPropPass ):
72+ p .update_placeholder_tensor_specs (program , new_gm )
73+
74+ _copy_module (program .graph_module , new_gm )
75+ else :
76+ program = p (program )
77+
78+ return program
79+
80+
4981@final
5082class VulkanBackend (BackendDetails ):
5183 @classmethod
@@ -57,35 +89,44 @@ def preprocess( # noqa: C901
5789 ) -> PreprocessResult :
5890 program = unsafe_remove_auto_functionalized_pass (program )
5991
60- passes = [
61- RemoveCloneOpsTransform (),
62- AddmmToLinearTransform (),
63- FuseDequantLinearPass (),
64- FuseViewCopyTransform (),
65- FuseBatchNormWithConvPass (program ),
66- FuseClampPass (),
67- SpecPropPass (),
68- ConstraintBasedSymShapeEvalPass (),
69- RemoveLocalScalarDenseOpsTransform (),
70- MemoryPlanningPass (),
71- ]
72-
73- new_gm = program .graph_module
74-
75- for p in passes :
76- # This is a workaround to allow the memory planning pass to work without
77- # having to first apply ToOutVarPass(). See the `greedy()` function in
78- # `exir.memory_planning`; if this attribute isn't set, assertions in
79- # `collect_spec_from_nodes()` will fail.
80- if isinstance (p , MemoryPlanningPass ):
81- new_gm .encounter_to_out_var_failure = True
82- new_gm_res = p (new_gm )
83- assert new_gm_res is not None
84- new_gm = new_gm_res .graph_module
92+ # First, apply passes that fuse/remove operators to consolidate the graph
93+ # structure but still preserve an "ATen-compliant" graph structure (i.e. all
94+ # arguments to ATen operators must match the ATen function schema).
95+ program = apply_passes (
96+ program ,
97+ [
98+ RemoveCloneOpsTransform (),
99+ AddmmToLinearTransform (),
100+ FuseDequantLinearPass (),
101+ FuseViewCopyTransform (),
102+ FuseBatchNormWithConvPass (program ),
103+ FuseClampPass (),
104+ ],
105+ )
85106
86- _copy_module (program .graph_module , new_gm )
107+ # Next annotate tensor nodes with TensorSpec structs which is needed for dynamic
108+ # shapes and memory planning. Until this point, the graph must be ATen compliant
109+ # because SpecPropPass will be calling the underlying ATen operators during its
110+ # execution.
111+ program = apply_passes (program , [SpecPropPass ()])
112+
113+ # Apply graph transforms which either require `TensorSpec`s to have been created
114+ # or would create an non ATen compliant graph structure.
115+ program = apply_passes (
116+ program ,
117+ [
118+ # Since this pass may replace a scalar argument with a tensor argument,
119+ # this pass may result in a non ATen compliant graph structure.
120+ RemoveLocalScalarDenseOpsTransform (),
121+ insert_prepack_nodes ,
122+ ],
123+ )
87124
88- program = insert_prepack_nodes (program )
125+ # Finally, apply dynamic shape passes and memory planning pass. These passes
126+ # must be applied only when the graph structure is finalized.
127+ program = apply_passes (
128+ program , [ConstraintBasedSymShapeEvalPass (), MemoryPlanningPass ()]
129+ )
89130
90131 graph_builder = VkGraphBuilder (
91132 program , DelegateMappingBuilder (generated_identifiers = True )
0 commit comments