Commit ef7af5c

Update
[ghstack-poisoned]
2 parents: a27d18c + 5cc4941

35 files changed (+781, -234 lines)

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
@@ -269,7 +269,7 @@ jobs:
           if [[ ${{ matrix.os}} == "bare_metal" ]]; then
             bash test/build_size_test.sh "-DCMAKE_TOOLCHAIN_FILE=${toolchain_cmake} -DEXECUTORCH_BUILD_ARM_BAREMETAL=ON"
           elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then
-            CXXFLAGS=${cxx_flags} cmake --preset zephyr -DCMAKE_BUILD_TYPE=Release -DEXECUTORCH_OPTIMIZE_SIZE=ON -DCMAKE_INSTALL_PREFIX=cmake-out -Bcmake-out .
+            CXXFLAGS=${cxx_flags} cmake --preset zephyr -DCMAKE_BUILD_TYPE=Release -DEXECUTORCH_OPTIMIZE_SIZE=ON -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON -DCMAKE_INSTALL_PREFIX=cmake-out -Bcmake-out .
           cmake --build cmake-out -j9 --target install --config Release
           CXXFLAGS=${cxx_flags} cmake -DCMAKE_TOOLCHAIN_FILE=${toolchain_cmake} -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=cmake-out -Bcmake-out/test test
           cmake --build cmake-out/test -j9 --config Release

backends/cadence/aot/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -41,6 +41,7 @@ python_library(
         ":ops_registrations",
         ":passes",
         ":replace_ops",
+        ":compiler_funcs",
         ":utils",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot/quantizer:fusion_pass",
@@ -332,6 +333,18 @@ python_library(
     ],
 )

+python_library(
+    name = "compiler_funcs",
+    srcs = [
+        "compiler_funcs.py",
+    ],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+        "//pytorch/ao:torchao",
+    ],
+)
+

 python_unittest(
     name = "test_graph_builder",

backends/cadence/aot/compiler.py

Lines changed: 39 additions & 44 deletions
@@ -12,6 +12,11 @@

 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
+from executorch.backends.cadence.aot.compiler_funcs import (
+    convert as convert_fn,
+    prepare as prepare_fn,
+    trace as trace_fn,
+)
 from executorch.backends.cadence.aot.memory_planning import (
     CadenceMemoryPlanning,
     print_memory_planning_info,
@@ -35,16 +40,13 @@
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from executorch.exir.program._program import to_edge
-from torch._inductor.decomposition import remove_decompositions

 from torch.export.exported_program import ExportedProgram
-from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

 from .passes import apply_exir_ops_passes, apply_torch_ops_passes

 from .utils import print_ops_info

-
 default_quantizer = CadenceDefaultQuantizer()

@@ -62,13 +64,6 @@ def trace(
     Trace the model with export and return an ExportedProgram.
     """

-    # Make the model inference mode by calling model.eval()
-    model.eval()
-
-    # Get default decompositions
-    decomp_table = torch.export.default_decompositions()
-
-    # Select ops to keep
     ops_to_keep = [
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv2d.default,
@@ -78,63 +73,54 @@
         torch.ops.aten.rms_norm.default,
     ]

-    # Remove decompositions for the ops we want to keep
-    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
-    remove_decompositions(decomp_table, ops_to_keep)
-
-    # Export with dynamo
-    program = torch.export.export(model, inputs, strict=True).run_decompositions(
-        decomp_table
+    program = trace_fn(
+        model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
     )

     if dump_graphs:
         logging.info("Graph before quantization:")
-        logging.info(program.module().graph.print_tabular())
+        logging.info(program.graph_module.graph.print_tabular())

     return program


-def prepare_and_convert_pt2(
+def prepare_pt2(
     program: ExportedProgram,
-    inputs: tuple[object, ...],
     quantizer: CadenceQuantizer,
-    calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
 ) -> torch.fx.GraphModule:
     """
-    Prepare and convert a model using the given quantizer.
+    Prepare a model using the given quantizer.
     The quantizer must be supplied and be the same as the one used to
     fuse the model later, if applicable. If you do not expect that behavior,
     please use quantize_and_fuse_pt2 instead, which will instantiate a
     default quantizer for you if needed.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the converted model.
+    Returns a GraphModule with the prepared model.
     """

-    # Get the graph module from the ExportedProgram
-    model_gm = program.module()
-
-    assert isinstance(model_gm, torch.fx.GraphModule)
-
-    # Prepare
-    prepared_model = prepare_pt2e(model_gm, quantizer)
-
-    # Calibrate
-    # If no calibration data is provided, use the inputs
-    if calibration_data is None:
-        calibration_data = [inputs]
-
-    for samples in calibration_data:
-        prepared_model(*samples)
-
-    # Convert
-    converted_model = convert_pt2e(prepared_model)
+    prepared_model = prepare_fn(program, quantizer, is_qat=False)
+
+    if dump_graphs:
+        logging.info("Graph after preparation:")
+        logging.info(prepared_model.graph.print_tabular())
+
+    return prepared_model
+
+
+def convert_pt2(
+    graph_module: torch.fx.GraphModule,
+    dump_graphs: bool = False,
+) -> torch.fx.GraphModule:
+    """
+    Convert the model
+    Returns a GraphModule with the converted model.
+    """
+
+    converted_model = convert_fn(graph_module)

     if dump_graphs:
-        logging.info("Graph after quantization (before fusion):")
-        logging.info(model_gm.graph.print_tabular())
+        logging.info("Graph after convert:")
+        logging.info(converted_model.graph.print_tabular())

     return converted_model

@@ -192,10 +178,19 @@ def quantize_pt2(
         logging.info("Graph after trace:")
         logging.info(program.graph.print_tabular())

+    # Get prepared graph module
+    prepared_gm = prepare_pt2(program, quantizer, dump_graphs=dump_graphs)
+
+    # Calibrate
+    # If no calibration data is provided, use the inputs
+    if calibration_data is None:
+        calibration_data = [inputs]
+
+    for samples in calibration_data:
+        prepared_gm(*samples)
+
     # Get converted graph module
-    converted_gm = prepare_and_convert_pt2(
-        program, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
-    )
+    converted_gm = convert_pt2(prepared_gm, dump_graphs=dump_graphs)

     # Get fused model
     fused_gm = fuse_pt2(converted_gm, quantizer)
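
Taken together, these changes split the old prepare_and_convert_pt2 into explicit prepare and convert stages, with calibration happening in caller code in between; this mirrors the export_example.py update later in this commit. A minimal sketch of the new sequence, where the toy model, inputs, and the CadenceDefaultQuantizer import path are illustrative assumptions rather than part of the commit:

import torch
from executorch.backends.cadence.aot.compiler import (
    convert_pt2,
    fuse_pt2,
    prepare_pt2,
    trace,
)
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceDefaultQuantizer,  # assumed import path
)

model = torch.nn.Linear(16, 16)  # illustrative toy model
example_inputs = (torch.randn(1, 16),)
quantizer = CadenceDefaultQuantizer()

# Trace to an ExportedProgram; trace() now handles eval mode and
# the decomposition table internally via compiler_funcs.
ep = trace(model, example_inputs)

# Insert observers, then calibrate by running representative samples.
prepared_gm = prepare_pt2(ep, quantizer)
for samples in [example_inputs]:
    prepared_gm(*samples)

# Convert observers to quant/dequant ops, then fuse with the same quantizer.
converted_gm = convert_pt2(prepared_gm)
fused_gm = fuse_pt2(converted_gm, quantizer)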
backends/cadence/aot/compiler_funcs.py

Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+
+from typing import Optional
+
+import torch
+from torch._inductor.decomposition import remove_decompositions
+from torchao.quantization.pt2e.quantize_pt2e import (
+    convert_pt2e,
+    prepare_pt2e,
+    prepare_qat_pt2e,
+)
+from torchao.quantization.pt2e.quantizer import Quantizer
+
+
+@torch.no_grad()
+def trace(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    is_qat: bool = False,
+    strict: bool = False,
+    ops_to_keep: Optional[list[torch._ops.OpOverload]] = None,
+) -> torch.export.ExportedProgram:
+    if is_qat:
+        model.train()
+    else:
+        model.eval()
+
+    decomp_table = torch.export.default_decompositions()
+    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
+    remove_decompositions(decomp_table, ops_to_keep)
+    program = torch.export.export_for_training(
+        model, inputs, strict=strict
+    ).run_decompositions(decomp_table)
+
+    return program
+
+
+def prepare(
+    traced_program: torch.export.ExportedProgram,
+    quantizer: Quantizer,
+    is_qat: bool = False,
+) -> torch.fx.GraphModule:
+    traced_model = traced_program.module()
+    assert isinstance(traced_model, torch.fx.GraphModule)
+
+    if is_qat:
+        prepared_model = prepare_qat_pt2e(traced_model, quantizer)
+    else:
+        prepared_model = prepare_pt2e(traced_model, quantizer)
+
+    return prepared_model
+
+
+def convert(prepared_model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    converted_model = convert_pt2e(prepared_model)
+    return converted_model
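
Besides factoring out the PT2E plumbing, compiler_funcs exposes a QAT path (prepare_qat_pt2e) that the previous inline implementation in compiler.py did not. A minimal sketch of that path, assuming CadenceDefaultQuantizer (import path assumed) is acceptable as the torchao pt2e Quantizer and using an illustrative toy model:

import torch
from executorch.backends.cadence.aot.compiler_funcs import convert, prepare, trace
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceDefaultQuantizer,  # assumed import path; any pt2e Quantizer should fit
)

model = torch.nn.Linear(8, 8)  # illustrative toy model
inputs = (torch.randn(2, 8),)
quantizer = CadenceDefaultQuantizer()

# is_qat=True keeps the model in train mode before export_for_training.
program = trace(model, inputs, is_qat=True)

# prepare() dispatches to prepare_qat_pt2e for QAT, prepare_pt2e otherwise.
prepared = prepare(program, quantizer, is_qat=True)

# ... QAT fine-tuning of `prepared` would go here ...

converted = convert(prepared)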

backends/cadence/aot/export_example.py

Lines changed: 10 additions & 2 deletions
@@ -15,9 +15,10 @@
 from typing import Any, Tuple

 from executorch.backends.cadence.aot.compiler import (
+    convert_pt2,
     export_to_executorch_gen_etrecord,
     fuse_pt2,
-    prepare_and_convert_pt2,
+    prepare_pt2,
     trace,
 )

@@ -52,8 +53,15 @@ def export_model(
     # Trace the model
     ep = trace(model, example_inputs)

+    # Prepare the model
+    prepared_gm = prepare_pt2(ep, quantizer)
+
+    # Calibrate the model
+    for samples in [example_inputs]:
+        prepared_gm(*samples)
+
     # Convert the model
-    converted_model = prepare_and_convert_pt2(ep, example_inputs, quantizer)
+    converted_model = convert_pt2(prepared_gm)

     # Get reference outputs from converted model
     ref_outputs = converted_model(*example_inputs)

backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 1 deletion
@@ -103,8 +103,8 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.pow.Tensor_Scalar,
         exir_ops.edge.aten.prelu.default,
         exir_ops.edge.aten.repeat.default,
-        exir_ops.edge.aten.round.default,
         exir_ops.edge.aten.relu.default,
+        exir_ops.edge.aten.round.default,
         exir_ops.edge.aten.sigmoid.default,
         exir_ops.edge.aten.split_with_sizes.default,
         exir_ops.edge.aten.split_with_sizes_copy.default,

backends/qualcomm/quantizer/annotators.py

Lines changed: 4 additions & 2 deletions
@@ -278,7 +278,9 @@ def annotate_masked_fill(node: Node, quantization_config: QuantizationConfig) ->
     )


-@register_annotator([torch.ops.aten.mul, torch.ops.aten.mul.Tensor])
+@register_annotator(
+    [torch.ops.aten.mul, torch.ops.aten.mul.Tensor, torch.ops.aten.mul_.Tensor]
+)
 def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)

@@ -1311,7 +1313,7 @@ def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None:
     )


-@register_annotator([torch.ops.aten.zeros.default])
+@register_annotator([torch.ops.aten.zeros.default, torch.ops.aten.zeros_like.default])
 def annotate_zeros(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]) or not _is_float_tensor(node):
         return

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 7 additions & 4 deletions
@@ -153,7 +153,9 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
     )


-def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
+def annotate_matmul_16a8w(  # noqa: C901
+    gm: torch.fx.GraphModule, annotate_conv=True
+) -> None:
     """
     This function is specific for matmul op 16a8w.
     For k, we will tag such as the below, and
@@ -317,9 +319,10 @@ def annotate_matmul_input1(node: Node):
             # The arguments of cat op: (the past kv cache, the new kv cache)
             node = node.args[0][1]
         elif node.target == torch.ops.aten.conv2d.default:
-            annotate_conv2d(
-                node, quantization_config=quantization_config_8a4w_per_channel
-            )
+            if annotate_conv:
+                annotate_conv2d(
+                    node, quantization_config=quantization_config_8a4w_per_channel
+                )
             break
         elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]:
             break
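
The new annotate_conv flag lets callers apply the 16a8w matmul annotation while leaving conv2d nodes to some other quantization scheme. A hedged sketch of how a caller might use it; the toy module and exported graph are illustrative, not from this commit:

import torch
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
)

class TinyMatmul(torch.nn.Module):  # illustrative toy module
    def forward(self, x, y):
        return torch.matmul(x, y)

gm = torch.export.export(
    TinyMatmul(), (torch.randn(2, 4), torch.randn(4, 2))
).module()

# Annotate matmuls as 16a8w but skip the 8a4w per-channel conv2d annotation.
annotate_matmul_16a8w(gm, annotate_conv=False)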

backends/qualcomm/runtime/backends/QnnOpPackageManager.h

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
  */
 #pragma once
 #include <mutex>
+#include <string>
 #include <unordered_set>

 namespace executorch {

backends/qualcomm/scripts/build.sh

Lines changed: 8 additions & 0 deletions
@@ -85,6 +85,7 @@ if [ "$BUILD_AARCH64" = true ]; then
     -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \
     -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
     -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \
+    -DEXECUTORCH_ENABLE_LOGGING=ON \
     -DQNN_SDK_ROOT=$QNN_SDK_ROOT \
     -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK_ROOT/build/cmake/android.toolchain.cmake \
     -DANDROID_ABI='arm64-v8a' \
@@ -104,6 +105,9 @@ if [ "$BUILD_AARCH64" = true ]; then
     -DANDROID_ABI='arm64-v8a' \
     -DANDROID_PLATFORM=android-30 \
     -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \
+    -DSUPPORT_REGEX_LOOKAHEAD=ON \
+    -DBUILD_TESTING=OFF \
+    -DEXECUTORCH_ENABLE_LOGGING=ON \
     -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
     -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
     -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
@@ -134,6 +138,7 @@ if [ "$BUILD_X86_64" = true ]; then
     -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
     -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
     -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \
+    -DEXECUTORCH_ENABLE_LOGGING=ON \
     -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
     -S $PRJ_ROOT \
     -B $BUILD_ROOT \
@@ -157,6 +162,9 @@ if [ "$BUILD_X86_64" = true ]; then
     -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \
     -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
     -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
+    -DSUPPORT_REGEX_LOOKAHEAD=ON \
+    -DBUILD_TESTING=OFF \
+    -DEXECUTORCH_ENABLE_LOGGING=ON \
     -B$EXAMPLE_ROOT

 cmake --build $EXAMPLE_ROOT -j$BUILD_JOB_NUMBER
