Commit 5394756

Update on "[ET-VK] Clean up prepack API"
## Context

As title, revamp the prepacking API:

* Make the naming clearer; i.e. rename `prepack_if_tensor_ref` to `prepack_standard` to disambiguate the packing logic that will be used.
* Instead of passing through the `v` argument if it is a `Tensor` by default, this functionality must now be toggled via the `passthrough` argument. The goal is to encourage developers to be explicit about what types they expect the operator arguments to be.
* Consolidate the API surface and reduce the number of overloads.

Beyond the API changes, I have also removed a number of unnecessary calls to `prepack_if_tensor_ref` throughout the operator implementations. The most common case was calling it on an input tensor, which is not necessary.

## The "big picture" for prepacking

`TensorRef` objects and prepacking are used whenever we are dealing with a tensor whose data is serialized with the model. These "serialized tensors" all belong to one of two categories:

* Weights/biases: trained weights and biases that act as the state for e.g. a Convolutional or Linear layer. These tensors are used only within the `nn.Module` they belong to.
* Persistent tensors: tensors whose data just happens to be invariant to the inputs, so their data can be serialized with the model itself. They are treated as regular tensors and may be used by several operators throughout the model. One example is `freqs_sin` and `freqs_cos` in Llama models, which are used to calculate rotary positional encodings.

For weights and biases, how the serialized data should be packed may depend on the operator it is used in. Persistent tensors, however, must be packed with the "standard" staging-to-tensor algorithm, since they behave exactly like regular tensors.

It is well known which operators expect weight tensors. Persistent tensors are trickier because they can be used as an argument to any operator. That would mean every operator needs to account for the possibility that one of its inputs is a serialized tensor. This is undesirable because it adds an additional layer of indirection when processing operator inputs, on top of the fact that every argument is already a reference to a `Value` object in the graph, which is itself a wrapper. It also complicates things for the operator developer. Another downside is that a persistent tensor would be packed multiple times, once by each operator that uses it.

To address this, I plan to handle persistent tensors at export time by inserting a `prepack()` operator for them, so that operators consuming the serialized tensor see a `Tensor` object instead of a `TensorRef` object. This makes it so that the only operators that should expect to prepack an argument are those that take a weight argument, and it also avoids packing persistent tensors multiple times.

Differential Revision: [D64550560](https://our.internmc.facebook.com/intern/diff/D64550560/)

[ghstack-poisoned]
2 parents 2835329 + fdba35e
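To make the export-time plan in the commit message concrete, here is a minimal, hedged sketch of what inserting a `prepack()` node for serialized constants could look like as a torch.fx graph pass. The `prepack` function and the `is_serialized_constant` predicate are hypothetical stand-ins for illustration only; they are not actual ExecuTorch or ET-VK APIs.

```python
# Illustrative sketch only: `prepack` and `is_serialized_constant` are
# hypothetical stand-ins, not real ExecuTorch/ET-VK APIs.
from typing import Callable

import torch
import torch.fx


def prepack(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the delegate's prepack() operator, which
    # would turn a serialized TensorRef into a packed Tensor.
    return t


def insert_prepack_nodes(
    gm: torch.fx.GraphModule,
    is_serialized_constant: Callable[[torch.fx.Node], bool],
) -> torch.fx.GraphModule:
    """Route every serialized-constant placeholder through a single prepack()
    call, so downstream operators see a regular Tensor and the constant is
    packed exactly once."""
    for node in list(gm.graph.nodes):
        if node.op == "placeholder" and is_serialized_constant(node):
            with gm.graph.inserting_after(node):
                packed = gm.graph.call_function(prepack, (node,))
            # Rewire all users of the placeholder except the prepack node
            # we just created.
            node.replace_all_uses_with(
                packed, delete_user_cb=lambda user: user is not packed
            )
    gm.recompile()
    return gm
```

With a pass along these lines, only operators that genuinely take weight arguments would still need to call `prepack_standard` themselves.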

File tree: 50 files changed, +1220 -240 lines

.ci/scripts/build_llama_android.sh

Lines changed: 0 additions & 2 deletions

```diff
@@ -19,7 +19,6 @@ install_executorch_and_backend_lib() {
   cmake -DBUCK2="${BUCK2}" \
     -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
     -DANDROID_ABI="${ANDROID_ABI}" \
-    -DANDROID_PLATFORM=android-23 \
     -DCMAKE_INSTALL_PREFIX=cmake-android-out \
     -DCMAKE_BUILD_TYPE=Release \
     -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
@@ -41,7 +40,6 @@ build_llama_runner() {
   cmake -DBUCK2="${BUCK2}" \
     -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK"/build/cmake/android.toolchain.cmake \
     -DANDROID_ABI="${ANDROID_ABI}" \
-    -DANDROID_PLATFORM=android-23 \
     -DCMAKE_INSTALL_PREFIX=cmake-android-out \
     -DCMAKE_BUILD_TYPE=Release -DPYTHON_EXECUTABLE=python \
     -DEXECUTORCH_BUILD_XNNPACK=ON \
```

.ci/scripts/test_llava.sh

Lines changed: 0 additions & 2 deletions

```diff
@@ -56,7 +56,6 @@ cmake_install_executorch_libraries_for_android() {
   cmake \
     -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
     -DANDROID_ABI=arm64-v8a \
-    -DANDROID_PLATFORM=android-23 \
     ${EXECUTORCH_COMMON_CMAKE_ARGS} \
     -B${BUILD_DIR} .

@@ -93,7 +92,6 @@ cmake_build_llava_runner_for_android() {
   cmake \
     -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
     -DANDROID_ABI=arm64-v8a \
-    -DANDROID_PLATFORM=android-23 \
     ${LLAVA_COMMON_CMAKE_ARGS} \
     -DCMAKE_PREFIX_PATH="$python_lib" \
     -DLLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE=ON \
```
.github/scripts/propose_ghstack_orig_pr.py

Lines changed: 135 additions & 0 deletions

```diff
@@ -0,0 +1,135 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+import re
+
+from typing import List
+
+# Provided by the PyGithub pip package.
+from github import Auth, Github
+from github.Repository import Repository
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
+    )
+    parser.add_argument(
+        "--repo",
+        type=str,
+        help='The github repo to modify: e.g. "pytorch/executorch".',
+        required=True,
+    )
+    parser.add_argument(
+        "--pr",
+        type=int,
+        help="Number of the PR in the stack to check and create corresponding PR",
+        required=True,
+    )
+    return parser.parse_args()
+
+
+def extract_stack_from_body(pr_body: str) -> List[int]:
+    """Extracts a list of PR numbers from a ghexport-generated PR body.
+
+    The base of the stack is in index 0.
+    """
+
+    # Expected format. The `__->__` could appear on any line. Stop parsing
+    # after the blank line. This would return [1, 2, 3].
+    """
+    Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
+    * #3
+    * __->__ #2
+    * #1
+
+    <PR description details>
+    """
+
+    prs = []
+    ghstack_begin = (
+        "Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):"
+    )
+    ghstack_begin_seen = False
+    for line in pr_body.splitlines():
+        if ghstack_begin in line:
+            ghstack_begin_seen = True
+        if not ghstack_begin_seen:
+            continue
+        match = re.match(r"\*(?:.*?)? #(\d+)", line)
+        if match:
+            # It's a bullet followed by an integer.
+            prs.append(int(match.group(1)))
+    return list(reversed(prs))
+
+
+def get_pr_stack_from_number(pr_number: int, repo: Repository) -> List[int]:
+    pr_stack = extract_stack_from_body(repo.get_pull(pr_number).body)
+
+    if not pr_stack:
+        raise Exception(
+            f"Could not find PR stack in body of #{pr_number}. "
+            + "Please make sure that the PR was created with ghstack."
+        )
+
+    return pr_stack
+
+
+def create_prs_for_orig_branch(pr_stack: List[int], repo: Repository):
+    # For the first PR, we want to merge to `main` branch, and we will update
+    # as we go through the stack
+    orig_branch_merge_base = "main"
+    for i in range(len(pr_stack)):
+        pr = repo.get_pull(pr_stack[i])
+        if not pr.is_merged():
+            print("The PR (and stack above) is not merged yet, skipping")
+            return
+        # Check for invariant: For the current PR, it must be gh/user/x/base <- gh/user/x/head
+        assert pr.base.ref.replace("base", "head") == pr.head.ref
+        # The PR we want to create is then "branch_to_merge" <- gh/user/x/orig
+        # gh/user/x/orig is the clean diff between gh/user/x/base <- gh/user/x/head
+        orig_branch_merge_head = pr.base.ref.replace("base", "orig")
+        bot_metadata = f"""This PR was created by the merge bot to help merge the original PR into the main branch.
+ghstack PR number: https://github.com/pytorch/executorch/pull/{pr.number}
+^ Please use this as the source of truth for the PR details, comments, and reviews
+ghstack PR base: https://github.com/pytorch/executorch/tree/{pr.base.ref}
+ghstack PR head: https://github.com/pytorch/executorch/tree/{pr.head.ref}
+Merge bot PR base: https://github.com/pytorch/executorch/tree/{orig_branch_merge_base}
+Merge bot PR head: https://github.com/pytorch/executorch/tree/{orig_branch_merge_head}"""
+
+        existing_orig_pr = repo.get_pulls(
+            head="pytorch:" + orig_branch_merge_head,
+            base=orig_branch_merge_base,
+            state="open",
+        )
+        if existing_orig_pr.totalCount > 0:
+            print(
+                f"PR for {orig_branch_merge_head} already exists {existing_orig_pr[0]}"
+            )
+            # We don't need to create/edit because the head PR is merged and orig is finalized.
+        else:
+            repo.create_pull(
+                base=orig_branch_merge_base,
+                head=orig_branch_merge_head,
+                title=pr.title,
+                body=bot_metadata,
+            )
+        # Advance the base for the next PR
+        orig_branch_merge_base = orig_branch_merge_head
+
+
+def main():
+    args = parse_args()
+
+    with Github(auth=Auth.Token(os.environ["GITHUB_TOKEN"])) as gh:
+        repo = gh.get_repo(args.repo)
+        create_prs_for_orig_branch(get_pr_stack_from_number(args.pr, repo), repo)
+
+
+if __name__ == "__main__":
+    main()
```
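As a quick illustration of the parsing logic above, here is a hedged usage sketch that assumes `extract_stack_from_body` from this script is in scope; it feeds in the exact example body from the function's docstring:

```python
# Assumes extract_stack_from_body from the script above is importable.
body = """\
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #3
* __->__ #2
* #1

<PR description details>
"""
print(extract_stack_from_body(body))  # [1, 2, 3] -- base of the stack first
```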

.github/workflows/ghstack_land.yml

Lines changed: 40 additions & 0 deletions

```diff
@@ -0,0 +1,40 @@
+name: Propose to merge ghstack orig PRs to main
+on:
+  pull_request:
+    types: [closed]
+    branches:
+      - 'gh/cccclai/[0-9]+/base'
+      - 'gh/dbort/[0-9]+/base'
+      - 'gh/guangy10/[0-9]+/base'
+      - 'gh/helunwencser/[0-9]+/base'
+      - 'gh/jorgep31415/[0-9]+/base'
+      - 'gh/kimishpatel/[0-9]+/base'
+      - 'gh/kirklandsign/[0-9]+/base'
+      - 'gh/larryliu0820/[0-9]+/base'
+      - 'gh/manuelcandales/[0-9]+/base'
+      - 'gh/mcr229/[0-9]+/base'
+      - 'gh/swolchok/[0-9]+/base'
+      - 'gh/SS-JIA/[0-9]+/base'
+
+jobs:
+  ghstack_merge_to_main:
+    name: Try to create a PR with ghstack /orig branch
+    runs-on: ubuntu-22.04
+    environment: cherry-pick-bot
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: '0'
+      - uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Try to merge PR to main
+        run: |
+          pip install pygithub
+
+          PR_NUMBER=$(echo "$GITHUB_REF" | grep -oE '[0-9]+')
+
+          python .github/scripts/propose_ghstack_orig_pr.py --pr $PR_NUMBER --repo pytorch/executorch
+        env:
+          GITHUB_TOKEN: ${{ secrets.GH_PYTORCHBOT_CHERRY_PICK_TOKEN }}
+          GITHUB_REF: ${{ github.ref }}
```
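Aside: the `grep -oE '[0-9]+'` step above pulls the PR number out of `GITHUB_REF`. A rough Python equivalent, assuming the usual `refs/pull/<number>/merge` shape of pull_request refs (the ref value here is an invented example):

```python
import re

# Assumed example ref; for pull_request events GITHUB_REF normally looks
# like "refs/pull/<number>/merge".
ref = "refs/pull/6357/merge"
pr_number = int(re.search(r"\d+", ref).group())  # 6357
```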

CMakeLists.txt

Lines changed: 3 additions & 0 deletions

```diff
@@ -372,6 +372,9 @@ endif()
 # Detect if an Android toolchain is set.
 if(CMAKE_TOOLCHAIN_FILE MATCHES ".*android\.toolchain\.cmake$")
   set(CMAKE_TOOLCHAIN_ANDROID ON)
+  if(NOT ANDROID_PLATFORM)
+    set(ANDROID_PLATFORM android-30)
+  endif()
 else()
   set(CMAKE_TOOLCHAIN_ANDROID OFF)
 endif()
```

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 2 additions & 6 deletions

```diff
@@ -9,6 +9,7 @@
 from typing import cast

 import torch
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 from executorch.backends.arm.tosa_quant_utils import dq_op
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -52,12 +53,7 @@ def call(self, graph_module: torch.fx.GraphModule):
         NHWC_Order = (0, 2, 3, 1)
         HWCM_Order = (2, 3, 0, 1)
         for node in graph_module.graph.nodes:
-            if isinstance(
-                node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
-            ):
-                node_data = node.meta["val"][0].data
-            else:
-                node_data = node.meta["val"].data
+            node_data = get_first_fake_tensor(node).data

             if len(node_data.shape) == 4:
                 dim_order = NHWC_Order
```

backends/arm/_passes/arm_pass_manager.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -22,6 +22,7 @@
 from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
     InsertSqueezeAfterSumPass,
 )
+from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
 from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
     ConvertMeanDimToAveragePool,
 )
@@ -30,6 +31,9 @@
     ScalarsToAttributePass,
 )
 from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
+from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
+    UnsqueezeScalarPlaceholdersPass,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.pass_manager import PassManager
@@ -45,10 +49,12 @@ def transform_to_backend_pipeline(
     ):
         """Apply passes before transforming program to backend"""
         self.add_pass(CastInt64ToInt32Pass(exported_program))
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(RemoveClonePass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(ConvertMeanDimToAveragePool())
+        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(DecomposeDivPass())
         self.add_pass(InsertSqueezeAfterSumPass())
         self.add_pass(ConvertSplitToSlicePass())
@@ -61,6 +67,6 @@ def transform_to_backend_pipeline(
         return self._transform(exported_program.graph_module)

     def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
-        self.add_pass(DecomposeDivPass())
         self.add_pass(ScalarsToAttributePass())
+        self.add_pass(DecomposeDivPass())
         return self._transform(graph_module)
```

backends/arm/_passes/arm_pass_utils.py

Lines changed: 20 additions & 0 deletions

```diff
@@ -7,9 +7,11 @@
 from typing import Optional

 import torch
+import torch.fx

 from executorch.exir.dialects._ops import ops as exir_ops
 from torch._ops import OpOverload
+from torch._subclasses.fake_tensor import FakeTensor


 def create_node(
@@ -64,3 +66,21 @@ def insert_q_dq_pair(
     # node's first use
     q.args = (anchor,) + q_params
     return dq
+
+
+def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
+    """
+    Returns a FakeTensor from the meta field of 'node'.
+    If the node contains many fake tensors, return the first one.
+    """
+    if isinstance(
+        node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
+    ):
+        fake_tensor = node.meta["val"][0]
+    else:
+        fake_tensor = node.meta["val"]
+
+    assert isinstance(
+        fake_tensor, FakeTensor
+    ), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
+    return fake_tensor
```
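As a quick illustration of the new helper's contract, here is a hedged, self-contained sketch; the node setup is hypothetical and assumes `get_first_fake_tensor` is imported from `arm_pass_utils`:

```python
import torch
import torch.fx
from torch._subclasses.fake_tensor import FakeTensorMode

# Hypothetical setup exercising both branches of get_first_fake_tensor.
with FakeTensorMode() as mode:
    ft = mode.from_tensor(torch.empty(2, 3))

graph = torch.fx.Graph()
single = graph.placeholder("x")
single.meta["val"] = ft           # plain FakeTensor
multi = graph.placeholder("y")
multi.meta["val"] = (ft, ft)      # tuple of FakeTensors

assert get_first_fake_tensor(single) is ft
assert get_first_fake_tensor(multi) is ft  # first element is returned
```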

backends/arm/_passes/decompose_div_pass.py

Lines changed: 6 additions & 3 deletions

```diff
@@ -8,15 +8,18 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass

+edge_div_ops = (exir_ops.edge.aten.div.Tensor,)
+aten_div_ops = (torch.ops.aten.div.Tensor, torch.ops.aten.div_.Tensor)
+

 def get_div_decomposition(op) -> tuple:
     """
     Returns the (reciprocal_op, mul_op) pair, where the ops depend on whether
     the div op is in exir_ops or torch.ops.aten.
     """
-    if op == exir_ops.edge.aten.div.Tensor:
+    if op in edge_div_ops:
         return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
-    if op == torch.ops.aten.div.Tensor:
+    if op in aten_div_ops:
         return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
     raise RuntimeError(f"Can't get div decomposition for op {op}")

@@ -33,7 +36,7 @@ class DecomposeDivPass(ExportPass):
     """

     def call_operator(self, op, args, kwargs, meta):
-        if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor):
+        if op not in (edge_div_ops + aten_div_ops):
             return super().call_operator(op, args, kwargs, meta)

         reciprocal_op, mul_op = get_div_decomposition(op)
```