
Commit a337a5c

Merge branch 'main' into gh/ahmtox/24/orig
2 parents 9aba8d6 + b5f950b commit a337a5c


57 files changed: +1638 −936 lines

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-9b498d3bb28b8e3411ce464dd2755c5b96d92c8f
+7cda4017ddda554752e89069ae205be5e8388f59

.ci/scripts/check_c10_sync.sh

Lines changed: 1 addition & 1 deletion
@@ -12,4 +12,4 @@ pushd pytorch
 git checkout "$pytorch_pin"
 popd
 "$(dirname "${BASH_SOURCE[0]}")"/compare_dirs.sh runtime/core/portable_type/c10/c10 pytorch/c10
-"$(dirname "${BASH_SOURCE[0]}")"/compare_dirs.sh runtime/core/portable_type/c10/torch/standalone pytorch/torch/standalone
+"$(dirname "${BASH_SOURCE[0]}")"/compare_dirs.sh runtime/core/portable_type/c10/torch/headeronly pytorch/torch/headeronly

.github/workflows/trunk.yml

Lines changed: 3 additions & 3 deletions
@@ -240,11 +240,11 @@ jobs:
 
 cxx_flags="-fno-exceptions -fno-rtti -Wall -Werror -Wno-int-in-bool-context -DET_HAVE_PREAD=0"
 setup_script_args=""
-if [[ ${{ matrix.os}} == "bare_metal" ]]; then
+if [[ ${{ matrix.os}} == "bare_metal" ]]; then
   toolchain_prefix=arm-none-eabi-
-  threshold="103268" # ~100KiB
+  threshold="104000" # should be ~103.7KB, set threshold to 104KB.
   toolchain_cmake=examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake
-elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then
+elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then
   setup_script_args="--target-toolchain zephyr"
   toolchain_prefix=arm-zephyr-eabi-
   threshold="133120" # should be ~125KB, set threshold to 130KB
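
The comments in this hunk mix KiB (1024 bytes) and KB (1000 bytes): 103268 bytes is roughly 100 KiB, 104000 bytes is exactly 104 KB, and 133120 bytes is exactly 130 KiB even though the comment says "130KB". A quick sanity check of the conversions (my own arithmetic, not part of the commit):

# Size-test thresholds from the hunk above, in both unit conventions.
for threshold, label in [
    (103268, "old bare_metal"),
    (104000, "new bare_metal"),
    (133120, "zephyr-preset"),
]:
    print(f"{label}: {threshold} bytes = "
          f"{threshold / 1024:.1f} KiB = {threshold / 1000:.1f} KB")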

CMakeLists.txt

Lines changed: 5 additions & 1 deletion
@@ -490,7 +490,7 @@ install(
   INCLUDES
   DESTINATION ${_common_include_directories}
 )
-install(FILES tools/cmake/executorch-config.cmake
+install(FILES tools/cmake/Utils.cmake tools/cmake/executorch-config.cmake
         DESTINATION lib/cmake/ExecuTorch
 )
 
@@ -732,4 +732,8 @@ if(EXECUTORCH_BUILD_VULKAN)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/vulkan)
 endif()
 
+if(EXECUTORCH_BUILD_ANDROID_JNI)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/android)
+endif()
+
 include(Test.cmake)

backends/cadence/aot/compiler.py

Lines changed: 29 additions & 39 deletions
@@ -8,7 +8,7 @@
 
 import logging
 from pathlib import Path
-from typing import Callable, cast, Optional
+from typing import Optional
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -32,7 +32,6 @@
     ExecutorchBackendConfig,
     ExecutorchProgramManager,
 )
-from executorch.exir.pass_base import PassResult
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from executorch.exir.program._program import to_edge_with_preserved_ops
@@ -41,7 +40,7 @@
 from torch.export.exported_program import ExportedProgram
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from .passes import get_cadence_passes
+from .passes import apply_exir_ops_passes, apply_torch_ops_passes
 
 from .utils import print_ops_info
 
@@ -210,6 +209,21 @@ def quantize_pt2(
     return program
 
 
+TO_EDGE_OP_EXCEPTION_LIST: list[torch._ops.OpOverload] = [
+    torch.ops.aten._linalg_det.default,
+    torch.ops.aten._linalg_svd.default,
+    torch.ops.aten._native_batch_norm_legit_functional.default,
+    torch.ops.aten.linear.default,
+    torch.ops.aten.linalg_vector_norm.default,
+    torch.ops.aten.unfold.default,
+    torch.ops.aten.angle.default,
+    torch.ops.aten.rms_norm.default,
+]
+TO_EDGE_PRESERVE_OPS: tuple[torch._ops.OpOverload, ...] = (
+    torch.ops.aten.rms_norm.default,
+)
+
+
 def _lower_ep_to_edge(
     expo_program: ExportedProgram,
     dump_graphs: bool = False,
@@ -226,20 +240,11 @@ def _lower_ep_to_edge(
         compile_config=EdgeCompileConfig(
             _skip_dim_order=True,
             # Allow specific non-core aten ops in the IR.
-            _core_aten_ops_exception_list=[
-                torch.ops.aten._linalg_det.default,
-                torch.ops.aten._linalg_svd.default,
-                torch.ops.aten._native_batch_norm_legit_functional.default,
-                torch.ops.aten.linear.default,
-                torch.ops.aten.linalg_vector_norm.default,
-                torch.ops.aten.unfold.default,
-                torch.ops.aten.angle.default,
-                torch.ops.aten.rms_norm.default,
-            ]
+            _core_aten_ops_exception_list=TO_EDGE_OP_EXCEPTION_LIST
             + (core_aten_exceptions or []),
         ),
         constant_methods=constant_methods,
-        preserve_ops=(torch.ops.aten.rms_norm.default,),
+        preserve_ops=TO_EDGE_PRESERVE_OPS,
     )
 
     if dump_graphs:
@@ -256,14 +261,20 @@ def export_to_edge(
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
    constant_methods: Optional[dict[str, object]] = None,
+    core_aten_exceptions: Optional[list[torch._ops.OpOverload]] = None,
 ) -> EdgeProgramManager:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
     # Export the model into an ExportedProgram.
     expo_program = trace(model, inputs)
 
+    # Apply passes which transform the ExportedProgram before it gets lowered to edge.
+    expo_program = apply_torch_ops_passes(expo_program)
+
     # Lower the model to edge IR.
-    edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
+    edge_prog_manager = _lower_ep_to_edge(
+        expo_program, dump_graphs, constant_methods, core_aten_exceptions
+    )
 
     return edge_prog_manager
 
@@ -305,14 +316,7 @@ def _lower_ep_to_cadence(
     Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
     """
     edge_prog_manager = _lower_ep_to_edge(program, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager
 
 
@@ -323,14 +327,7 @@ def export_to_cadence(
     opt_level: int = 1,
 ) -> EdgeProgramManager:
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager
 
 
@@ -367,15 +364,8 @@ def export_to_executorch_gen_etrecord(
     memory_config: Optional[MemoryConfig] = None,
     dump_graphs: bool = False,
 ) -> ExecutorchProgramManager:
-    cadence_passes = get_cadence_passes(opt_level)
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
 
     # Print some information to terminal
     print_ops_info(
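
Taken together, this refactor splits the Cadence frontend into a torch.ops-level stage (apply_torch_ops_passes, run on the ExportedProgram before lowering) and an EXIR-level stage (apply_exir_ops_passes, run on the EdgeProgramManager), and exposes a core_aten_exceptions hook on export_to_edge. A minimal sketch of how the new entry point might be called; the toy model and the extra op in the exception list are my own illustration, assuming the executorch package is importable:

import torch

from executorch.backends.cadence.aot.compiler import export_to_edge


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # aten.linear is already in TO_EDGE_OP_EXCEPTION_LIST above.
        return torch.nn.functional.linear(x, torch.ones(4, 4))


# Any extra non-core aten ops a model needs can now be passed through
# the new core_aten_exceptions argument (illustrative choice of op).
edge_manager = export_to_edge(
    Toy(),
    (torch.randn(2, 4),),
    core_aten_exceptions=[torch.ops.aten.pixel_shuffle.default],
)
print(edge_manager.exported_program().graph_module.graph)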

backends/cadence/aot/fuse_ops.py

Lines changed: 1 addition & 0 deletions
@@ -1127,6 +1127,7 @@ class CadenceFuseOpsInGraph:
         FuseCascadedTransposeOrPermuteOps,
         FuseCascadedViewOps,
         FuseQuantDequantToRequantizePass,
+        FuseMulTensorIntoQuantPass,
         FuseMulTensorIntoDequantPass,
         FuseMulScalarIntoDequantPass,
         FuseFullThenReshapePass,
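
The commit does not show FuseMulTensorIntoQuantPass itself, but the arithmetic that makes this family of fusions sound is that a multiplication by a constant c feeding a quantize can be absorbed by dividing the quantization scale by c: round(x * c / s) equals round(x / (s / c)). A numeric check of that identity (illustration only, not the pass's code):

import torch

x, c = torch.randn(8), 2.5           # input and a constant multiplier
scale = 0.1                          # example quantization scale

# Quantizing (x * c) with scale `scale` is the same computation as
# quantizing x with scale `scale / c`, so the mul can be folded away.
lhs = x * c / scale
rhs = x / (scale / c)
print(torch.allclose(lhs, rhs))  # True (up to float rounding)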

backends/cadence/aot/passes.py

Lines changed: 36 additions & 8 deletions
@@ -6,7 +6,7 @@
 
 # pyre-strict
 
-from typing import Any, List, Optional
+from typing import Any, Callable, cast, List, Optional
 
 import torch
 import torch.fx
@@ -28,13 +28,18 @@
     RemoveRedundantOps,
 )
 from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
-from executorch.backends.cadence.aot.replace_ops import CadenceReplaceOpsInGraph
+from executorch.backends.cadence.aot.replace_ops import (
+    CadenceReplaceOpsInGraph,
+    ReplaceMulTensorWithMulAndFullOpsPass,
+)
 from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
+from executorch.exir import EdgeProgramManager
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.pass_manager import PassManager, PassType
 from executorch.exir.passes import dead_code_elimination_pass
 from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
+from torch.export.exported_program import ExportedProgram
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -89,14 +94,37 @@ def get_passes_in_default_order() -> List[ExportPass]:
     return pytree.tree_flatten(passes)[0]
 
 
-def get_cadence_passes(
+def apply_exir_ops_passes(
     opt_level: int,
-) -> List[Optional[PassResult]]:
+    edge_prog_manager: EdgeProgramManager,
+) -> EdgeProgramManager:
     passes = get_passes_in_default_order()
     pass_filter = create_cadence_pass_filter(opt_level)
-    filtered_passes = [
-        # pyre-ignore[20]: Expect argument graph_module
-        filtered_pass()
+    cadence_passes = [
+        (
+            lambda graph_module, filtered_pass=filtered_pass: filtered_pass()(
+                graph_module
+            )
+        )
         for filtered_pass in list(filter(pass_filter, passes))
     ]
-    return filtered_passes
+    cadence_prog_manager = edge_prog_manager.transform(
+        cast(
+            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
+        )
+    )
+    return cadence_prog_manager
+
+
+def apply_torch_ops_passes(expo_program: ExportedProgram) -> ExportedProgram:
+    """
+    Applies compiler passes on torch.ops IR, including torch.ops.aten, torch.ops.cadence, etc.
+    expo_program is expected to be the output of the torch.export.export().
+    """
+
+    aten_passes: List[Callable[[torch.fx.GraphModule], Optional[PassResult]]] = [
+        ReplaceMulTensorWithMulAndFullOpsPass()
+    ]
+    # TODO(T230417247): Use PassResult which is currently ignored.
+    PassManager(aten_passes)(expo_program.graph_module)
+    return expo_program
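
One detail worth noting in apply_exir_ops_passes: the comprehension binds filtered_pass through a lambda default argument rather than a free variable. Python closures capture variables late, so without the default every lambda would end up invoking whichever pass the loop variable held last. A standalone illustration of the pitfall (my own example, unrelated to the ExecuTorch APIs):

# Late binding: every closure sees the final value of `i`.
late = [lambda: i for i in range(3)]
print([f() for f in late])  # [2, 2, 2]

# Default-argument binding captures the value at definition time, which is
# the trick used by `lambda graph_module, filtered_pass=filtered_pass: ...`.
early = [lambda i=i: i for i in range(3)]
print([f() for f in early])  # [0, 1, 2]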

backends/nxp/backend/edge_helper.py

Lines changed: 33 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 import torch
 from torch.fx import Node
+from torch.nn import Parameter
 
 
 def input_tensor(node: Node, input_index: int) -> torch.Tensor:
@@ -38,3 +39,35 @@ def input_tensor_safe(node: Node, input_index: int) -> torch.Tensor | None:
         return None
 
     return input_tensor(node, input_index)
+
+
+def node_is_static_tensor(node: Node, parameters_mapping: dict[str, Parameter]) -> bool:
+    """Return `True` if the given `node` has static data in the `parameters_mapping` dict.
+    :param node: Tensor node to check for data.
+    :param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the
+        `state_dict` attribute of an edge program.
+    """
+    return node.name in parameters_mapping.keys()
+
+
+def node_is_effectively_static_tensor(
+    node: Node, parameters_mapping: dict[str, Parameter]
+) -> bool:
+    """Return `True` if the given `node` has static data, or follows after a `Dequantize` node with a static input.
+    In the IR, the `node` will be turned into a static quantized tensor.
+    :param node: Tensor node to check for data.
+    :param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the
+        `state_dict` attribute of an edge program.
+    """
+    if node_is_static_tensor(node, parameters_mapping):
+        return True
+
+    def _is_dequantize(node_: Node) -> bool:
+        return node_.target.__name__ in {
+            "quantized_decomposed.dequantize_per_tensor.default",
+            "quantized_decomposed.dequantize_per_channel.default",
+        }
+
+    return _is_dequantize(node) and node_is_static_tensor(
+        node.args[0], parameters_mapping
+    )
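
A minimal sketch of how these helpers might be used in a partitioner-style check; edge_program is an assumption for illustration (an ExportedProgram whose state_dict provides the name-to-Parameter mapping the docstrings describe):

from executorch.backends.nxp.backend.edge_helper import (
    node_is_effectively_static_tensor,
)

# Per the docstrings, the mapping is inferred from the program's state_dict.
parameters_mapping = dict(edge_program.state_dict)

for node in edge_program.graph_module.graph.nodes:
    if node.op == "placeholder" and node_is_effectively_static_tensor(
        node, parameters_mapping
    ):
        # Static (or dequantize-of-static) tensors can be treated as
        # constant data rather than runtime inputs.
        print(f"{node.name} is (effectively) static")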
