
Commit bfe548d

Update on "[ET-VK] Replace Uniform buffers with push constants for permute op"
This diff replaces uniform buffers with push constants for the permute op in the Vulkan backend of ExecuTorch. The changes include updating the GLSL code to use push constants instead of uniform buffers and updating the C++ code to pass the sizes as push constants to the shader.

Differential Revision: [D66890825](https://our.internmc.facebook.com/intern/diff/D66890825/)

[ghstack-poisoned]
2 parents: 5a4de9d + 96fb18e
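The Vulkan permute change itself is not among the file diffs reproduced below. As a rough sketch of the mechanism the commit message describes, here is a minimal, generic Vulkan example of moving a shader parameter from a uniform buffer to push constants. This is not ExecuTorch's actual backend API; all names (PermutePushConstants, record_permute_dispatch) are illustrative.

```cpp
// Hypothetical sketch, assuming a compute pipeline whose VkPipelineLayout was
// created with a VkPushConstantRange covering PermutePushConstants.
#include <vulkan/vulkan.h>
#include <cstdint>

// GLSL (before): sizes read from a uniform buffer object, e.g.
//   layout(set = 0, binding = 2) uniform Sizes { ivec4 sizes; };
// GLSL (after): sizes read from a push constant block, e.g.
//   layout(push_constant) uniform PushConstants { ivec4 sizes; } pc;

struct PermutePushConstants {
  int32_t sizes[4];  // tensor sizes handed to the permute shader
};

void record_permute_dispatch(
    VkCommandBuffer cmd,
    VkPipelineLayout layout,
    VkPipeline pipeline,
    const PermutePushConstants& pc,
    uint32_t groups_x, uint32_t groups_y, uint32_t groups_z) {
  vkCmdBindPipeline(cmd, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
  // Instead of writing the sizes into a UBO and binding it through a
  // descriptor set, push them directly into the command buffer.
  vkCmdPushConstants(
      cmd, layout, VK_SHADER_STAGE_COMPUTE_BIT, /*offset=*/0,
      sizeof(PermutePushConstants), &pc);
  vkCmdDispatch(cmd, groups_x, groups_y, groups_z);
}
```

Push constants avoid the per-dispatch buffer update and descriptor binding, which is the usual motivation for this kind of switch for small, frequently changing parameters such as tensor sizes.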

File tree: 77 files changed (+2449, -417 lines). Large commit: only a subset of the changed files is shown below.

.ci/docker/ubuntu/Dockerfile

Lines changed: 0 additions & 3 deletions
@@ -79,9 +79,6 @@ RUN if [ -n "${ANDROID_NDK_VERSION}" ]; then bash ./install_android.sh; fi
 RUN rm install_android.sh
 
 ARG ARM_SDK
-COPY --chown=ci-user:ci-user ./arm /opt/arm
-# Set up ARM SDK if needed
-RUN if [ -n "${ARM_SDK}" ]; then git config --global user.email "[email protected]"; git config --global user.name "OSS CI"; bash /opt/arm/setup.sh --i-agree-to-the-contained-eula /opt/arm-sdk; chown -R ci-user:ci-user /opt/arm-sdk; fi
 
 ARG QNN_SDK
 

.ci/scripts/build_llama_android.sh

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,8 @@ source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"
 
 install_executorch_and_backend_lib() {
   echo "Installing executorch and xnnpack backend"
-  rm -rf cmake-android-out && mkdir cmake-android-out
+  clean_executorch_install_folders
+  mkdir cmake-android-out
   ANDROID_NDK=/opt/ndk
   BUCK2=buck2
   ANDROID_ABI=arm64-v8a

.ci/scripts/utils.sh

Lines changed: 7 additions & 2 deletions
@@ -16,6 +16,10 @@ retry () {
   "$@" || (sleep 30 && reset_buck && "$@") || (sleep 60 && reset_buck && "$@")
 }
 
+clean_executorch_install_folders() {
+  ./install_requirements.sh --clean
+}
+
 install_executorch() {
   which pip
   # Install executorch, this assumes that Executorch is checked out in the
@@ -74,7 +78,8 @@ build_executorch_runner_buck2() {
 build_executorch_runner_cmake() {
   CMAKE_OUTPUT_DIR=cmake-out
   # Build executorch runtime using cmake
-  rm -rf "${CMAKE_OUTPUT_DIR}" && mkdir "${CMAKE_OUTPUT_DIR}"
+  clean_executorch_install_folders
+  mkdir "${CMAKE_OUTPUT_DIR}"
 
   pushd "${CMAKE_OUTPUT_DIR}" || return
   # This command uses buck2 to gather source files and buck2 could crash flakily
@@ -103,7 +108,7 @@ build_executorch_runner() {
 
 cmake_install_executorch_lib() {
   echo "Installing libexecutorch.a and libportable_kernels.a"
-  rm -rf cmake-out
+  clean_executorch_install_folders
   retry cmake -DBUCK2="$BUCK" \
     -DCMAKE_INSTALL_PREFIX=cmake-out \
     -DCMAKE_BUILD_TYPE=Release \

.github/workflows/docker-builds.yml

Lines changed: 0 additions & 4 deletions
@@ -7,8 +7,6 @@ on:
       - .ci/docker/**
       - .github/workflows/docker-builds.yml
       - requirements-lintrunner.txt
-      - examples/arm/setup.sh
-      - examples/arm/ethos-u-setup/**
   push:
     branches:
       - main
@@ -17,8 +15,6 @@ on:
       - .ci/docker/**
       - .github/workflows/docker-builds.yml
       - requirements-lintrunner.txt
-      - examples/arm/setup.sh
-      - examples/arm/ethos-u-setup/**
   schedule:
     - cron: 1 3 * * 3
 

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 .hypothesis
 buck-out/
+buck2-bin/
 cmake-out*
 .DS_Store
 cmake-android-out/

CMakeLists.txt

Lines changed: 11 additions & 0 deletions
@@ -257,6 +257,17 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
   set(EXECUTORCH_BUILD_KERNELS_OPTIMIZED ON)
 endif()
 
+if(NOT DEFINED FXDIV_SOURCE_DIR)
+  set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG
+      ${CMAKE_POSITION_INDEPENDENT_CODE}
+  )
+  set(FXDIV_SOURCE_DIR "backends/xnnpack/third-party/FXdiv")
+  add_subdirectory("${FXDIV_SOURCE_DIR}")
+  set(CMAKE_POSITION_INDEPENDENT_CODE
+      ${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG}
+  )
+endif()
+
 if(EXECUTORCH_BUILD_CPUINFO)
   # --- cpuinfo
   set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 117 additions & 37 deletions
@@ -12,6 +12,7 @@
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
+    get_node_arg,
     insert_q_dq_pair,
 )
 from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
@@ -83,14 +84,48 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
 
         return False
 
-    def insert_input_transpose(self, node, input_node, graph_module):
+    @staticmethod
+    def memory_format_differs(shape):
+        """Returns true if the shape will have a different memory layout in NCHW and NHWC format"""
+        if len(shape) >= 4:
+            C = shape[1]
+            H = shape[2]
+            W = shape[3]
+        elif len(shape) == 3:
+            C = shape[0]
+            H = shape[1]
+            W = shape[2]
+        if len(shape) <= 2:
+            return False
+
+        return C > 1 and (H > 1 or W > 1)
+
+    @staticmethod
+    def is_channel_reshape(input_shape, output_shape):
+        """Returns true if the reshape changes the channel dimension"""
+        if not len(input_shape) == len(output_shape) == 4:
+            return False
+
+        C_old = input_shape[1]
+        C_new = output_shape[1]
+
+        N_new = output_shape[0]
+        N_old = input_shape[0]
+
+        return (N_old != N_new) or (C_old != C_new)
+
+    @staticmethod
+    def insert_input_transpose(node, input_node, graph_module):
         quantize = input_node.target == dq_op
         q_params = input_node.args[1:] if quantize else None
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 torch.ops.passthrough_to_tosa._transpose,
-                args=(input_node, list(self.NHWC_inverse_order)),
+                args=(
+                    input_node,
+                    list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
+                ),
                 quantize=quantize,
                 q_params=q_params,
             )
@@ -100,14 +135,17 @@ def insert_input_transpose(self, node, input_node, graph_module):
                 range(len(input_node.meta["val"].size()))
             )
 
-    def insert_output_transpose(self, node, graph_module):
+    @staticmethod
+    def insert_output_transpose(node, graph_module):
         with graph_module.graph.inserting_after(node):
             permute_node = create_node(
                 graph_module.graph,
                 torch.ops.passthrough_to_tosa._transpose,
-                args=(node, list(self.NHWC_order)),
+                args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
+            )
+            permute_node.meta["tosa_dim_order"] = (
+                AnnotateChannelsLastDimOrder.NHWC_order
             )
-            permute_node.meta["tosa_dim_order"] = self.NHWC_order
             node.meta["tosa_dim_order"] = (0, 1, 2, 3)
             users = [user for user in node.users if user != permute_node]
             for user in users:
@@ -118,54 +156,96 @@ def insert_output_transpose(self, node, graph_module):
                 q_params = node.args[0].args[1:]
                 insert_q_dq_pair(graph_module.graph, node, q_params)
 
+    @staticmethod
+    def _insert_squeeze_transpose(
+        input_shape, output_shape, node, input_node, graph_module
+    ):
+        nhwc_to_nhwc = len(input_shape) == 4 and len(output_shape) <= 3
+
+        if nhwc_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
+            input_shape
+        ):
+            AnnotateChannelsLastDimOrder.insert_input_transpose(
+                node, input_node, graph_module
+            )
+
+    @staticmethod
+    def _insert_unsqueeze_transpose(input_shape, output_shape, node, graph_module):
+        nchw_to_nhwc = len(input_shape) == 3 and len(output_shape) == 4
+        if nchw_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
+            output_shape
+        ):
+            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+
+    @staticmethod
+    def _insert_view_transpose(
+        input_shape, output_shape, node, input_node, graph_module
+    ):
+        nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) == 4
+        nhwc_to_nchw = len(input_shape) == 4 and len(output_shape) < 4
+        channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
+            output_shape, input_shape
+        )
+
+        if (
+            channel_reshape or nhwc_to_nchw
+        ) and AnnotateChannelsLastDimOrder.memory_format_differs(input_shape):
+            AnnotateChannelsLastDimOrder.insert_input_transpose(
+                node, input_node, graph_module
+            )
+        if (
+            channel_reshape or nchw_to_nhwc
+        ) and AnnotateChannelsLastDimOrder.memory_format_differs(output_shape):
+            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+
     def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         """
-        Reshape operations are not equivalent in NCHW and NHWC.
-        To get around this, transposes need to be added if the previous or new shape
-        fulfil the following condition:
-            C > 1 and (H or W > 1)
-
-        This is relevant for the following operations;
-        squeeze: 4D -> 3D
-        unsqueeze: <4D -> 4D
-        view: <4D -> 4D
-        view: 4D -> <4D
-        view: 4D -> 4D
-        """
-
-        def transpose_condition(shape):
-            if len(shape) != 4:
-                return False
-            C = shape[1]
-            H = shape[2]
-            W = shape[3]
-            return C > 1 and (H > 1 or W > 1)
+        Transposes are needed for operators transforming the input to a different rank, as 4D-tensors are assumed to be in NHWC-format, whereas all other are in NCHW format.
+        This is relevant for the following cases:
+        - squeeze: 4D -> <4D
+        - unsqueeze: 3D -> 4D
+        - view: <4D -> 4D
+        - view: 4D -> <4D
+        Additionally, a 4D->4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leadning to one extra input and output transpose for this case.
 
+        Transposes can be avoided for shapes where there is no difference in actual memory, e.g for
+        - H == W == 1
+        - C == 1
+        - 1D/2D tensors
+        """
         for node in graph_module.graph.nodes:
             if node.op != "call_function":
                 continue
+
             if node.target == exir_ops.edge.aten.squeeze_copy.dims:
                 input_node = node.args[0]
                 input_shape = input_node.meta["val"].shape
-                if transpose_condition(input_shape):
-                    self.insert_input_transpose(node, input_node, graph_module)
+                output_shape = node.meta["val"].shape
+
+                self._insert_squeeze_transpose(
+                    input_shape, output_shape, node, input_node, graph_module
+                )
 
             elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
+                input_node = get_node_arg(node.args, 0, default_value=False)
+                if input_node:
+                    input_shape = input_node.meta["val"].shape
+                else:
+                    input_shape = ()
                 output_shape = node.meta["val"].shape
-                if transpose_condition(output_shape):
-                    self.insert_output_transpose(node, graph_module)
+
+                self._insert_unsqueeze_transpose(
+                    input_shape, output_shape, node, graph_module
+                )
 
             elif node.target == exir_ops.edge.aten.view_copy.default:
                 input_node = node.args[0]
+                input_shape = input_node.meta["val"].shape
+                output_shape = node.meta["val"].shape
 
-                old_shape = input_node.meta["val"].shape
-                new_shape = node.meta["val"].shape
-
-                if transpose_condition(old_shape):
-                    self.insert_input_transpose(node, input_node, graph_module)
-
-                if transpose_condition(new_shape):
-                    self.insert_output_transpose(node, graph_module)
+                self._insert_view_transpose(
+                    input_shape, output_shape, node, input_node, graph_module
+                )
 
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:

backends/arm/test/ops/test_view.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ class View(torch.nn.Module):
         (torch.rand(1, 1, 5, 10), (1, 1, 50, 1)),
         (torch.rand(5, 10, 1, 1), (1, 25, 2)),
         (torch.rand(2, 50, 1, 1), (1, 100)),
+        (torch.rand(2, 3, 2, 3), (2, 3, 3, 2)),
     ]
 
     def forward(self, x: torch.Tensor, new_shape):

backends/arm/tosa_quant_utils.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def quantize_value(x, qargs: QuantArgs, dtype=np.int8):
 
 
 def dequantize_value(qx, qargs: QuantArgs):
-    return (qx - qargs.zp) * qargs.scale
+    return (np.int64(qx) - qargs.zp) * qargs.scale
 
 
 def qargs_from_qnode(node: torch.fx.Node):
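The dequantize_value change above widens qx to int64 before subtracting the zero point, presumably to guard against narrow-integer overflow. A standalone C++ illustration of the kind of wraparound such a widening avoids (this is not the ExecuTorch code, which operates on numpy values; the values are made up):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  int8_t qx = -128;  // smallest int8 quantized value
  int8_t zp = 100;   // example zero point
  // Keeping the result in the narrow type wraps around (two's complement).
  int8_t wrapped = static_cast<int8_t>(qx - zp);        // 28, not -228
  // Widening first preserves the true difference.
  int64_t widened = static_cast<int64_t>(qx) - zp;      // -228
  std::cout << int(wrapped) << " vs " << widened << "\n";
  return 0;
}
```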
