Skip to content

Commit 6681329

Browse files
committed
Update on "[ET-VK][AOT] Define pass application order"
## Changes The goal of this diff is to enforce a specific structure in how graph transform passes are applied during `vulkan_preprocess`. This will help make sure that certain passes are applied at the correct time, and that pre-requisite conditions for passes are fulfilled before they are applied. See the comments in `vulkan_preprocess.py` for more details. Differential Revision: [D65234843](https://our.internmc.facebook.com/intern/diff/D65234843/) [ghstack-poisoned]
1 parent ef0877b commit 6681329

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
7171
exir_ops.edge.et_vk.prepack.default,
7272
(node,),
7373
)
74-
# This pass assumes that the SpecPropPass() has already been applied, and
75-
# validate that the original node is marked as a constant. Constant tensors
76-
# do not participate in memory planning.
74+
# This pass assumes that the SpecPropPass() has already been applied
7775
assert "spec" in node.meta
76+
# Validate that the original node is marked as a constant. Constant tensors
77+
# do not participate in memory planning.
7878
assert node.meta["spec"].const
7979
prepack_node.meta["spec"] = deepcopy(node.meta["spec"])
8080
# Set the mem_obj_id to -1 to indicate that this node requires a dedicated

backends/vulkan/vulkan_preprocess.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def preprocess( # noqa: C901
9090
program = unsafe_remove_auto_functionalized_pass(program)
9191

9292
# First, apply passes that fuse/remove operators to consolidate the graph
93-
# structure but still preserve an "ATen-compliant" graph structure.
93+
# structure but still preserve an "ATen-compliant" graph structure (i.e. all
94+
# arguments to aten operators must match the ATen function schema).
9495
program = apply_passes(
9596
program,
9697
[
@@ -104,8 +105,9 @@ def preprocess( # noqa: C901
104105
)
105106

106107
# Next annotate tensor nodes with TensorSpec structs which is needed for dynamic
107-
# shapes and memory planning. Until this point, the graph must be "ATen compliant"
108-
# (e.g. all arguments to aten operators must match the ATen function schema).
108+
# shapes and memory planning. Until this point, the graph must be ATen compliant
109+
# because SpecPropPass will be calling the underlying ATen operators during its
110+
# execution.
109111
program = apply_passes(program, [SpecPropPass()])
110112

111113
# Apply graph transforms which either require `TensorSpec`s to have been created

0 commit comments

Comments
 (0)