1717from executorch .backends .transforms .fuse_view_copy import FuseViewCopyTransform
1818from executorch .backends .transforms .remove_clone_ops import RemoveCloneOpsTransform
1919
20- from executorch .backends .vulkan ._passes import RemoveLocalScalarDenseOpsTransform
21- from executorch .backends .vulkan ._passes .insert_prepack_nodes import insert_prepack_nodes
20+ from executorch .backends .vulkan ._passes import (
21+ insert_prepack_nodes ,
22+ RemoveLocalScalarDenseOpsTransform ,
23+ )
2224
2325from executorch .backends .vulkan .serialization .vulkan_graph_builder import VkGraphBuilder
2426from executorch .backends .vulkan .serialization .vulkan_graph_serialize import (
3234 PreprocessResult ,
3335)
3436from executorch .exir .backend .utils import DelegateMappingBuilder
37+ from executorch .exir .pass_base import ExportPass , PassBase
3538
3639from executorch .exir .passes import MemoryPlanningPass , SpecPropPass
3740
4649DEFAULT_DEBUG_HANDLE = 65535
4750
4851
52+ # pyre-ignore
53+ def apply_passes (program : ExportedProgram , passes ) -> ExportedProgram :
54+ for pass_ in passes :
55+
56+ if issubclass (type (pass_ ), ExportPass ) or issubclass (type (pass_ ), PassBase ):
57+ new_gm = program .graph_module
58+ # This is a workaround to allow the memory planning pass to work without
59+ # having to first apply ToOutVarPass(). See the `greedy()` function in
60+ # `exir.memory_planning`; if this attribute isn't set, assertions in
61+ # `collect_spec_from_nodes()` will fail.
62+ if isinstance (pass_ , MemoryPlanningPass ):
63+ new_gm .encounter_to_out_var_failure = True
64+
65+ new_gm_res = pass_ (new_gm )
66+ assert new_gm_res is not None
67+ new_gm = new_gm_res .graph_module
68+
69+ # See the application of this function in exir/program/_program.py for more
70+ # details on why this step is necessary.
71+ if isinstance (pass_ , SpecPropPass ):
72+ pass_ .update_placeholder_tensor_specs (program , new_gm )
73+
74+ _copy_module (program .graph_module , new_gm )
75+ else :
76+ program = pass_ (program )
77+
78+ return program
79+
80+
4981@final
5082class VulkanBackend (BackendDetails ):
5183 @classmethod
@@ -57,35 +89,42 @@ def preprocess( # noqa: C901
5789 ) -> PreprocessResult :
5890 program = unsafe_remove_auto_functionalized_pass (program )
5991
60- passes = [
61- RemoveCloneOpsTransform (),
62- AddmmToLinearTransform (),
63- FuseDequantLinearPass (),
64- FuseViewCopyTransform (),
65- FuseBatchNormWithConvPass (program ),
66- FuseClampPass (),
67- SpecPropPass (),
68- ConstraintBasedSymShapeEvalPass (),
69- RemoveLocalScalarDenseOpsTransform (),
70- MemoryPlanningPass (),
71- ]
72-
73- new_gm = program .graph_module
74-
75- for p in passes :
76- # This is a workaround to allow the memory planning pass to work without
77- # having to first apply ToOutVarPass(). See the `greedy()` function in
78- # `exir.memory_planning`; if this attribute isn't set, assertions in
79- # `collect_spec_from_nodes()` will fail.
80- if isinstance (p , MemoryPlanningPass ):
81- new_gm .encounter_to_out_var_failure = True
82- new_gm_res = p (new_gm )
83- assert new_gm_res is not None
84- new_gm = new_gm_res .graph_module
92+ # First, apply passes that fuse/remove operators to consolidate the graph
93+ # structure but still preserve an "ATen-compliant" graph structure.
94+ program = apply_passes (
95+ program ,
96+ [
97+ RemoveCloneOpsTransform (),
98+ AddmmToLinearTransform (),
99+ FuseDequantLinearPass (),
100+ FuseViewCopyTransform (),
101+ FuseBatchNormWithConvPass (program ),
102+ FuseClampPass (),
103+ ],
104+ )
85105
86- _copy_module (program .graph_module , new_gm )
106+ # Next annotate tensor nodes with TensorSpec structs which is needed for dynamic
107+ # shapes and memory planning. Until this point, the graph must be "ATen compliant"
108+ # (e.g. all arguments to aten operators must match the ATen function schema).
109+ program = apply_passes (program , [SpecPropPass ()])
110+
111+ # Apply graph transforms which either require `TensorSpec`s to have been created
112+ # or would create an non ATen compliant graph structure.
113+ program = apply_passes (
114+ program ,
115+ [
116+ # Since this pass may replace a scalar argument with a tensor argument,
117+ # this pass may result in a non ATen compliant graph structure.
118+ RemoveLocalScalarDenseOpsTransform (),
119+ insert_prepack_nodes ,
120+ ],
121+ )
87122
88- program = insert_prepack_nodes (program )
123+ # Finally, apply dynamic shape passes and memory planning pass. These passes
124+ # must be applied only when the graph structure is finalized.
125+ program = apply_passes (
126+ program , [ConstraintBasedSymShapeEvalPass (), MemoryPlanningPass ()]
127+ )
89128
90129 graph_builder = VkGraphBuilder (
91130 program , DelegateMappingBuilder (generated_identifiers = True )
0 commit comments