Commit ef0877b

[ET-VK][AOT] Define pass application order

## Changes

The goal of this diff is to enforce a specific structure in how graph transform passes are applied during `vulkan_preprocess`. This helps ensure that each pass is applied at the correct time, and that its prerequisite conditions are fulfilled before it runs. See the comments in `vulkan_preprocess.py` for more details.

Differential Revision: [D65234843](https://our.internmc.facebook.com/intern/diff/D65234843/)

[ghstack-poisoned]
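To illustrate the idea, here is a minimal, self-contained sketch (with hypothetical names, not the ExecuTorch implementation) of grouping passes into explicitly ordered stages, so that a pass which consumes metadata can never run before the pass that produces it:

```python
from typing import Any, Callable, List

GraphPass = Callable[[Any], Any]


def run_stages(program: Any, stages: List[List[GraphPass]]) -> Any:
    """Apply each group of passes in order; a later stage can rely on the
    invariants established by every earlier stage."""
    for stage in stages:
        for apply_pass in stage:
            program = apply_pass(program)
    return program


def annotate(program: dict) -> dict:
    # Producer stage: attaches metadata (as SpecPropPass attaches TensorSpecs).
    program["spec"] = {}
    return program


def consume(program: dict) -> dict:
    # Consumer stage: requires the metadata (as insert_prepack_nodes does).
    assert "spec" in program  # fails if the stages are reordered
    return program


program = run_stages({}, [[annotate], [consume]])
```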
1 parent fbb0acf commit ef0877b

File tree: 2 files changed (+77 −31 lines changed)

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 9 additions & 2 deletions

@@ -6,6 +6,8 @@
 
 # pyre-strict
 
+from copy import deepcopy
+
 import executorch.backends.vulkan.custom_ops_lib  # noqa
 
 import torch
@@ -69,9 +71,14 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
             exir_ops.edge.et_vk.prepack.default,
             (node,),
         )
-        prepack_node.meta["spec"] = node.meta["spec"]
+        # This pass assumes that SpecPropPass() has already been applied, and
+        # validates that the original node is marked as a constant. Constant
+        # tensors do not participate in memory planning.
+        assert "spec" in node.meta
+        assert node.meta["spec"].const
+        prepack_node.meta["spec"] = deepcopy(node.meta["spec"])
         # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
-        # memory object. This pass must be executed AFTER the memory planning pass.
+        # memory object.
         prepack_node.meta["spec"].mem_obj_id = -1
         node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
 
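The new assertions make the ordering dependency on SpecPropPass explicit. The following minimal, self-contained sketch shows the same pattern; `annotate_specs` and `mark_dedicated_memory` are hypothetical stand-ins for SpecPropPass and insert_prepack_nodes, not the real implementations:

```python
import torch


def annotate_specs(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Stand-in for SpecPropPass: attach a "spec" entry to every node's meta.
    for node in gm.graph.nodes:
        node.meta["spec"] = {"const": node.op == "get_attr", "mem_obj_id": None}
    return gm


def mark_dedicated_memory(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Stand-in for insert_prepack_nodes: consumes specs, so it must run second.
    for node in gm.graph.nodes:
        assert "spec" in node.meta  # fails if annotate_specs has not run yet
        if node.meta["spec"]["const"]:
            # Mirror the real pass: -1 requests a dedicated memory object.
            node.meta["spec"]["mem_obj_id"] = -1
    return gm


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = torch.fx.symbolic_trace(M())
for p in (annotate_specs, mark_dedicated_memory):  # order matters
    gm = p(gm)
```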

backends/vulkan/vulkan_preprocess.py

Lines changed: 68 additions & 29 deletions

@@ -17,8 +17,10 @@
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
 from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 
-from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform
-from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
+from executorch.backends.vulkan._passes import (
+    insert_prepack_nodes,
+    RemoveLocalScalarDenseOpsTransform,
+)
 
 from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
 from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
@@ -32,6 +34,7 @@
     PreprocessResult,
 )
 from executorch.exir.backend.utils import DelegateMappingBuilder
+from executorch.exir.pass_base import ExportPass, PassBase
 
 from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
 
@@ -46,6 +49,35 @@
 DEFAULT_DEBUG_HANDLE = 65535
 
 
+# pyre-ignore
+def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
+    for pass_ in passes:
+
+        if issubclass(type(pass_), ExportPass) or issubclass(type(pass_), PassBase):
+            new_gm = program.graph_module
+            # This is a workaround to allow the memory planning pass to work without
+            # having to first apply ToOutVarPass(). See the `greedy()` function in
+            # `exir.memory_planning`; if this attribute isn't set, assertions in
+            # `collect_spec_from_nodes()` will fail.
+            if isinstance(pass_, MemoryPlanningPass):
+                new_gm.encounter_to_out_var_failure = True
+
+            new_gm_res = pass_(new_gm)
+            assert new_gm_res is not None
+            new_gm = new_gm_res.graph_module
+
+            # See the application of this function in exir/program/_program.py for more
+            # details on why this step is necessary.
+            if isinstance(pass_, SpecPropPass):
+                pass_.update_placeholder_tensor_specs(program, new_gm)
+
+            _copy_module(program.graph_module, new_gm)
+        else:
+            program = pass_(program)
+
+    return program
+
+
 @final
 class VulkanBackend(BackendDetails):
     @classmethod
@@ -57,35 +89,42 @@ def preprocess( # noqa: C901
     ) -> PreprocessResult:
         program = unsafe_remove_auto_functionalized_pass(program)
 
-        passes = [
-            RemoveCloneOpsTransform(),
-            AddmmToLinearTransform(),
-            FuseDequantLinearPass(),
-            FuseViewCopyTransform(),
-            FuseBatchNormWithConvPass(program),
-            FuseClampPass(),
-            SpecPropPass(),
-            ConstraintBasedSymShapeEvalPass(),
-            RemoveLocalScalarDenseOpsTransform(),
-            MemoryPlanningPass(),
-        ]
-
-        new_gm = program.graph_module
-
-        for p in passes:
-            # This is a workaround to allow the memory planning pass to work without
-            # having to first apply ToOutVarPass(). See the `greedy()` function in
-            # `exir.memory_planning`; if this attribute isn't set, assertions in
-            # `collect_spec_from_nodes()` will fail.
-            if isinstance(p, MemoryPlanningPass):
-                new_gm.encounter_to_out_var_failure = True
-            new_gm_res = p(new_gm)
-            assert new_gm_res is not None
-            new_gm = new_gm_res.graph_module
+        # First, apply passes that fuse/remove operators to consolidate the graph
+        # structure but still preserve an "ATen-compliant" graph structure.
+        program = apply_passes(
+            program,
+            [
+                RemoveCloneOpsTransform(),
+                AddmmToLinearTransform(),
+                FuseDequantLinearPass(),
+                FuseViewCopyTransform(),
+                FuseBatchNormWithConvPass(program),
+                FuseClampPass(),
+            ],
+        )
 
-        _copy_module(program.graph_module, new_gm)
+        # Next, annotate tensor nodes with TensorSpec structs, which are needed for
+        # dynamic shapes and memory planning. Until this point, the graph must be
+        # "ATen compliant" (e.g. all arguments to aten operators must match the ATen
+        # function schema).
+        program = apply_passes(program, [SpecPropPass()])
+
+        # Apply graph transforms which either require `TensorSpec`s to have been
+        # created or would create a non ATen compliant graph structure.
+        program = apply_passes(
+            program,
+            [
+                # Since this pass may replace a scalar argument with a tensor
+                # argument, it may result in a non ATen compliant graph structure.
+                RemoveLocalScalarDenseOpsTransform(),
+                insert_prepack_nodes,
+            ],
+        )
 
-        program = insert_prepack_nodes(program)
+        # Finally, apply dynamic shape passes and the memory planning pass. These
+        # passes must be applied only when the graph structure is finalized.
+        program = apply_passes(
+            program, [ConstraintBasedSymShapeEvalPass(), MemoryPlanningPass()]
+        )
 
         graph_builder = VkGraphBuilder(
             program, DelegateMappingBuilder(generated_identifiers=True)
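The dispatch in `apply_passes` treats `ExportPass`/`PassBase` instances and plain program-level callables differently. Below is a self-contained sketch of that design choice, using hypothetical stand-in types rather than the real ExecuTorch classes:

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Union


@dataclass
class Program:
    # Stand-in for ExportedProgram: holds a mutable "graph module".
    graph_module: Dict[str, Any] = field(default_factory=dict)


class PassBase:
    # Stand-in for executorch.exir.pass_base.PassBase.
    def __call__(self, gm: Dict[str, Any]) -> Dict[str, Any]:
        raise NotImplementedError


class AnnotatePass(PassBase):
    def __call__(self, gm: Dict[str, Any]) -> Dict[str, Any]:
        gm["annotated"] = True
        return gm


def rewrite_program(program: Program) -> Program:
    # Program-level callable, like insert_prepack_nodes.
    program.graph_module["rewritten"] = True
    return program


def apply_passes(program: Program, passes: List[Union[PassBase, Callable]]) -> Program:
    for pass_ in passes:
        if isinstance(pass_, PassBase):
            # Graph-level pass: runs against the graph module only.
            program.graph_module = pass_(program.graph_module)
        else:
            # Program-level callable: takes and returns the whole program.
            program = pass_(program)
    return program


result = apply_passes(Program(), [AnnotatePass(), rewrite_program])
assert result.graph_module == {"annotated": True, "rewritten": True}
```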
