Commit d576208

Author: morelos
Committed: Update on "[ET] correcting cpu ref quantize_per_channel logic to align with ATen"
# Context

The quantize_per_channel implementation was not perfectly aligned with the ATen implementation and produced incorrect results when a non-default axis was specified. The bug went unnoticed because the test suite had only one test case for the whole operator.

# Changes

We change the core logic of quantize_per_channel to align properly with ATen's implementation, replacing the old `apply_over_dim_list` approach with a single-loop implementation that computes the channel index directly. This change also adds more comprehensive testing for quantize_per_channel so that a bug like this isn't missed again.

Differential Revision: [D77746130](https://our.internmc.facebook.com/intern/diff/D77746130/)

[ghstack-poisoned]
2 parents 96fcf40 + c37ce6c commit d576208
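
For readers unfamiliar with the fix, the sketch below illustrates the new indexing scheme in Python. The commit itself changes a C++ reference kernel, so this is only a minimal model of the idea, assuming a contiguous input tensor; `quantize_per_channel_ref` and its signature are hypothetical names used for illustration.

```python
import torch

def quantize_per_channel_ref(
    x: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
) -> torch.Tensor:
    # Product of the dimensions after `axis`; for a contiguous tensor this
    # is the stride of the channel dimension.
    inner = 1
    for d in range(axis + 1, x.dim()):
        inner *= x.size(d)
    num_channels = x.size(axis)

    flat = x.contiguous().flatten()
    out = torch.empty_like(flat, dtype=torch.int64)
    # Single loop with direct channel-index calculation: the channel of the
    # element at flat index i is (i // inner) % num_channels.
    for i in range(flat.numel()):
        c = (i // inner) % num_channels
        q = round(flat[i].item() / scales[c].item()) + int(zero_points[c].item())
        out[i] = min(max(q, quant_min), quant_max)
    return out.reshape(x.shape)
```

Because the channel index is computed per element, the same loop handles any `axis`, which is exactly the case the old per-dimension iteration got wrong.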

File tree: 80 files changed (+2181, -1049 lines)

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-9b498d3bb28b8e3411ce464dd2755c5b96d92c8f
+7cda4017ddda554752e89069ae205be5e8388f59

.ci/scripts/check_c10_sync.sh

Lines changed: 1 addition & 1 deletion
@@ -12,4 +12,4 @@ pushd pytorch
 git checkout "$pytorch_pin"
 popd
 "$(dirname "${BASH_SOURCE[0]}")"/compare_dirs.sh runtime/core/portable_type/c10/c10 pytorch/c10
-"$(dirname "${BASH_SOURCE[0]}")"/compare_dirs.sh runtime/core/portable_type/c10/torch/standalone pytorch/torch/standalone
+"$(dirname "${BASH_SOURCE[0]}")"/compare_dirs.sh runtime/core/portable_type/c10/torch/headeronly pytorch/torch/headeronly

.github/workflows/trunk.yml

Lines changed: 3 additions & 3 deletions
@@ -240,11 +240,11 @@ jobs:

 cxx_flags="-fno-exceptions -fno-rtti -Wall -Werror -Wno-int-in-bool-context -DET_HAVE_PREAD=0"
 setup_script_args=""
-if [[ ${{ matrix.os}} == "bare_metal" ]]; then
+if [[ ${{ matrix.os}} == "bare_metal" ]]; then
   toolchain_prefix=arm-none-eabi-
-  threshold="103268" # ~100KiB
+  threshold="104000" # should be ~103.7KB, set threshold to 104KB.
   toolchain_cmake=examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake
-elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then
+elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then
   setup_script_args="--target-toolchain zephyr"
   toolchain_prefix=arm-zephyr-eabi-
   threshold="133120" # should be ~125KB, set threshold to 130KB

.lintrunner.toml

Lines changed: 16 additions & 16 deletions
@@ -10,7 +10,7 @@ exclude_patterns = [
     'exir/serde/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -19,7 +19,7 @@ command = [
     '@{{PATHSFILE}}'
 ]
 init_command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -41,7 +41,7 @@ exclude_patterns = [
     'exir/serde/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -50,7 +50,7 @@ command = [
     '@{{PATHSFILE}}'
 ]
 init_command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -84,7 +84,7 @@ exclude_patterns = [
     'runtime/core/portable_type/c10/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -95,7 +95,7 @@ command = [
     '@{{PATHSFILE}}'
 ]
 init_command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -117,7 +117,7 @@ exclude_patterns = [
     '**/third-party/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -127,7 +127,7 @@ command = [
     '@{{PATHSFILE}}',
 ]
 init_command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -151,7 +151,7 @@ exclude_patterns = [
     '**/third-party/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -192,7 +192,7 @@ exclude_patterns = [
     'extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -234,7 +234,7 @@ exclude_patterns = [
     'util/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -287,7 +287,7 @@ exclude_patterns = [
     'util/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -337,7 +337,7 @@ exclude_patterns = [
     'backends/arm/test/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -349,7 +349,7 @@ command = [
     '@{{PATHSFILE}}'
 ]
 init_command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -368,7 +368,7 @@ exclude_patterns = [
     '.lintrunner.toml',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -397,7 +397,7 @@ exclude_patterns = [
 ]

 command = [
-    "python3",
+    "python",
     "-m",
     "lintrunner_adapters",
     "run",

CMakeLists.txt

Lines changed: 5 additions & 1 deletion
@@ -490,7 +490,7 @@ install(
   INCLUDES
   DESTINATION ${_common_include_directories}
 )
-install(FILES tools/cmake/executorch-config.cmake
+install(FILES tools/cmake/Utils.cmake tools/cmake/executorch-config.cmake
         DESTINATION lib/cmake/ExecuTorch
 )

@@ -732,4 +732,8 @@ if(EXECUTORCH_BUILD_VULKAN)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/vulkan)
 endif()

+if(EXECUTORCH_BUILD_ANDROID_JNI)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/android)
+endif()
+
 include(Test.cmake)

backends/cadence/aot/compiler.py

Lines changed: 29 additions & 39 deletions
@@ -8,7 +8,7 @@

 import logging
 from pathlib import Path
-from typing import Callable, cast, Optional
+from typing import Optional

 import executorch.backends.cadence.aot.ops_registrations # noqa
 import torch
@@ -32,7 +32,6 @@
     ExecutorchBackendConfig,
     ExecutorchProgramManager,
 )
-from executorch.exir.pass_base import PassResult
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from executorch.exir.program._program import to_edge_with_preserved_ops
@@ -41,7 +40,7 @@
 from torch.export.exported_program import ExportedProgram
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

-from .passes import get_cadence_passes
+from .passes import apply_exir_ops_passes, apply_torch_ops_passes

 from .utils import print_ops_info

@@ -210,6 +209,21 @@ def quantize_pt2(
     return program


+TO_EDGE_OP_EXCEPTION_LIST: list[torch._ops.OpOverload] = [
+    torch.ops.aten._linalg_det.default,
+    torch.ops.aten._linalg_svd.default,
+    torch.ops.aten._native_batch_norm_legit_functional.default,
+    torch.ops.aten.linear.default,
+    torch.ops.aten.linalg_vector_norm.default,
+    torch.ops.aten.unfold.default,
+    torch.ops.aten.angle.default,
+    torch.ops.aten.rms_norm.default,
+]
+TO_EDGE_PRESERVE_OPS: tuple[torch._ops.OpOverload, ...] = (
+    torch.ops.aten.rms_norm.default,
+)
+
+
 def _lower_ep_to_edge(
     expo_program: ExportedProgram,
     dump_graphs: bool = False,
@@ -226,20 +240,11 @@ def _lower_ep_to_edge(
         compile_config=EdgeCompileConfig(
             _skip_dim_order=True,
             # Allow specific non-core aten ops in the IR.
-            _core_aten_ops_exception_list=[
-                torch.ops.aten._linalg_det.default,
-                torch.ops.aten._linalg_svd.default,
-                torch.ops.aten._native_batch_norm_legit_functional.default,
-                torch.ops.aten.linear.default,
-                torch.ops.aten.linalg_vector_norm.default,
-                torch.ops.aten.unfold.default,
-                torch.ops.aten.angle.default,
-                torch.ops.aten.rms_norm.default,
-            ]
+            _core_aten_ops_exception_list=TO_EDGE_OP_EXCEPTION_LIST
             + (core_aten_exceptions or []),
         ),
         constant_methods=constant_methods,
-        preserve_ops=(torch.ops.aten.rms_norm.default,),
+        preserve_ops=TO_EDGE_PRESERVE_OPS,
     )

     if dump_graphs:
@@ -256,14 +261,20 @@ def export_to_edge(
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
     constant_methods: Optional[dict[str, object]] = None,
+    core_aten_exceptions: Optional[list[torch._ops.OpOverload]] = None,
 ) -> EdgeProgramManager:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

     # Export the model into an ExportedProgram.
     expo_program = trace(model, inputs)

+    # Apply passes which transform the ExportedProgram before it gets lowered to edge.
+    expo_program = apply_torch_ops_passes(expo_program)
+
     # Lower the model to edge IR.
-    edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
+    edge_prog_manager = _lower_ep_to_edge(
+        expo_program, dump_graphs, constant_methods, core_aten_exceptions
+    )

     return edge_prog_manager

@@ -305,14 +316,7 @@ def _lower_ep_to_cadence(
     Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
     """
     edge_prog_manager = _lower_ep_to_edge(program, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager


@@ -323,14 +327,7 @@
     opt_level: int = 1,
 ) -> EdgeProgramManager:
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager


@@ -367,15 +364,8 @@ def export_to_executorch_gen_etrecord(
     memory_config: Optional[MemoryConfig] = None,
     dump_graphs: bool = False,
 ) -> ExecutorchProgramManager:
-    cadence_passes = get_cadence_passes(opt_level)
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)

     # Print some information to terminal
     print_ops_info(
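
The repeated `transform(cast(...))` block deleted from three call sites above is now hidden behind `apply_exir_ops_passes`. The diff does not show that helper's body, but its shape can be inferred from the deleted code; below is a plausible sketch, assuming it lives next to `get_cadence_passes` in `backends/cadence/aot/passes.py` (the real implementation may differ):

```python
from typing import Callable, Optional, cast

import torch

from executorch.exir import EdgeProgramManager
from executorch.exir.pass_base import PassResult


def apply_exir_ops_passes(
    opt_level: int, edge_prog_manager: EdgeProgramManager
) -> EdgeProgramManager:
    # Same steps the old call sites performed inline: build the
    # opt-level-dependent pass list, then run it through transform().
    cadence_passes = get_cadence_passes(opt_level)  # assumed in scope in passes.py
    return edge_prog_manager.transform(
        cast(
            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]],
            cadence_passes,
        )
    )
```

Factoring the cast into one place also lets the `Callable`, `cast`, and `PassResult` imports drop out of compiler.py, as the import hunks above show.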

backends/cadence/aot/fuse_ops.py

Lines changed: 1 addition & 0 deletions
@@ -1127,6 +1127,7 @@ class CadenceFuseOpsInGraph:
     FuseCascadedTransposeOrPermuteOps,
     FuseCascadedViewOps,
     FuseQuantDequantToRequantizePass,
+    FuseMulTensorIntoQuantPass,
     FuseMulTensorIntoDequantPass,
     FuseMulScalarIntoDequantPass,
     FuseFullThenReshapePass,
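
The newly registered `FuseMulTensorIntoQuantPass` is only listed here; its implementation is not part of this diff. As a hedged illustration of the arithmetic that makes this kind of fusion sound at all (behavior inferred from the pass name, not from the diff): quantizing `k * x` with scale `s` is equivalent to quantizing `x` with scale `s / k`, so a constant multiply feeding a quantize can be folded into the quantization parameters.

```python
import torch

x = torch.randn(16)
# Power-of-two constants keep both computations bit-exact for this demo.
k, s, zp = 2.0, 0.125, 3  # multiplier, quant scale, zero point (illustrative)

q_mul_then_quant = torch.clamp(torch.round(x * k / s) + zp, -128, 127)
q_folded_scale = torch.clamp(torch.round(x / (s / k)) + zp, -128, 127)
assert torch.equal(q_mul_then_quant, q_folded_scale)
```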
