Skip to content

Commit 6e35bd1

Browse files
authored
Merge branch 'main' into export-D82560656
2 parents 999c6ce + 641e737 commit 6e35bd1

File tree

107 files changed

+2590
-185
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

107 files changed

+2590
-185
lines changed

.ci/scripts/test_model.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,13 @@ test_model_with_xnnpack() {
131131
return 0
132132
fi
133133

134-
# Delegation
134+
# Delegation and test with pybindings
135135
if [[ ${WITH_QUANTIZATION} == true ]]; then
136136
SUFFIX="q8"
137-
"${PYTHON_EXECUTABLE}" -m examples.xnnpack.aot_compiler --model_name="${MODEL_NAME}" --delegate --quantize
137+
"${PYTHON_EXECUTABLE}" -m examples.xnnpack.aot_compiler --model_name="${MODEL_NAME}" --delegate --quantize --test_after_export
138138
else
139139
SUFFIX="fp32"
140-
"${PYTHON_EXECUTABLE}" -m examples.xnnpack.aot_compiler --model_name="${MODEL_NAME}" --delegate
140+
"${PYTHON_EXECUTABLE}" -m examples.xnnpack.aot_compiler --model_name="${MODEL_NAME}" --delegate --test_after_export
141141
fi
142142

143143
OUTPUT_MODEL_PATH="${MODEL_NAME}_xnnpack_${SUFFIX}.pte"

.ci/scripts/test_wheel_package_qnn.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ run_core_tests () {
145145
echo "=== [$LABEL] Import smoke tests ==="
146146
"$PYBIN" -c "import executorch; print('executorch imported successfully')"
147147
"$PYBIN" -c "import executorch.backends.qualcomm; print('executorch.backends.qualcomm imported successfully')"
148+
"$PYBIN" -c "from executorch.export.target_recipes import get_android_recipe; recipe = get_android_recipe('android-arm64-snapdragon-fp16'); print(f'executorch.export.target_recipes imported successfully: {recipe}')"
148149

149150
echo "=== [$LABEL] List installed executorch/backends/qualcomm/python ==="
150151
local SITE_DIR

.ci/scripts/wheel/test_base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,18 @@ class ModelTest:
4141

4242

4343
def run_tests(model_tests: List[ModelTest]) -> None:
44+
# Test that we can import the portable_lib module - verifies RPATH is correct
45+
print("Testing portable_lib import...")
46+
try:
47+
from executorch.extension.pybindings._portable_lib import ( # noqa: F401
48+
_load_for_executorch,
49+
)
50+
51+
print("✓ Successfully imported _load_for_executorch from portable_lib")
52+
except ImportError as e:
53+
print(f"✗ Failed to import portable_lib: {e}")
54+
raise
55+
4456
# Why are we doing this envvar shenanigans? Since we build the testers, which
4557
# uses buck, we cannot run as root. This is a sneaky of getting around that
4658
# test.

CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,21 @@ if(EXECUTORCH_BUILD_PYBIND)
869869
target_compile_options(portable_lib PUBLIC ${_pybind_compile_options})
870870
target_link_libraries(portable_lib PRIVATE ${_dep_libs})
871871

872+
# Set RPATH to find PyTorch libraries relative to the installation location
873+
# This goes from executorch/extension/pybindings up to site-packages, then to
874+
# torch/lib
875+
if(APPLE)
876+
set_target_properties(
877+
portable_lib PROPERTIES BUILD_RPATH "@loader_path/../../../torch/lib"
878+
INSTALL_RPATH "@loader_path/../../../torch/lib"
879+
)
880+
else()
881+
set_target_properties(
882+
portable_lib PROPERTIES BUILD_RPATH "$ORIGIN/../../../torch/lib"
883+
INSTALL_RPATH "$ORIGIN/../../../torch/lib"
884+
)
885+
endif()
886+
872887
install(
873888
TARGETS portable_lib
874889
EXPORT ExecuTorchTargets

backends/arm/_passes/add_bias_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from typing import Set, Type
7+
68
import torch
79
from executorch.backends.arm._passes import ArmPass
810
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
911
from executorch.backends.transforms.utils import create_constant_placeholder
1012

1113
from executorch.exir.dialects._ops import ops as exir_ops
12-
from executorch.exir.pass_base import PassResult
14+
from executorch.exir.pass_base import ExportPass, PassResult
1315
from torch.export.graph_signature import InputKind
1416

1517

@@ -19,6 +21,8 @@ class AddBiasPass(ArmPass):
1921
The bias is set to zero.
2022
"""
2123

24+
_passes_required_after: Set[Type[ExportPass]] = set()
25+
2226
targeted_ops = (exir_ops.edge.aten.convolution.default,)
2327

2428
def call(self, graph_module):

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import itertools
99
import operator
10-
from typing import cast, List
10+
from typing import cast, List, Set, Type
1111

1212
import torch
1313
from executorch.backends.arm._passes.arm_pass_utils import create_node
@@ -29,6 +29,8 @@ class AnnotateDecomposedMatmulPass(ExportPass):
2929
matmul-op (can be mm or bmm).
3030
"""
3131

32+
_passes_required_after: Set[Type[ExportPass]] = set()
33+
3234
def _match_partition_to_node(
3335
self, node: torch.fx.Node, partitioned_inputs: List[torch.fx.Node]
3436
) -> torch.fx.Node:

backends/arm/_passes/annotate_output_dim_order_pass.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
7+
from typing import Set, Type
8+
69
from executorch.backends.arm._passes import ArmPass
710
from executorch.backends.arm._passes.arm_pass_utils import get_output_dim_orders
8-
from executorch.exir.pass_base import PassResult
11+
from executorch.exir.pass_base import ExportPass, PassResult
912

1013

1114
class AnnotateOutputDimOrderPass(ArmPass):
@@ -14,6 +17,8 @@ class AnnotateOutputDimOrderPass(ArmPass):
1417
for verifying that the dim order does not change unexpectedly in later passes.
1518
"""
1619

20+
_passes_required_after: Set[Type[ExportPass]] = set()
21+
1722
def call(self, graph_module):
1823
output_node = graph_module.graph.output_node()
1924
output_node.meta["original_dim_orders"] = get_output_dim_orders(graph_module)

backends/arm/_passes/arm_pass.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
# pyre-unsafe
77

88
import traceback
9-
from typing import Optional
9+
from abc import abstractmethod
10+
from typing import List, Optional, Set, Type
1011

1112
import torch
1213
from executorch.exir.pass_base import ExportPass, NodeMetadata
@@ -19,6 +20,36 @@ def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = No
1920
super(ArmPass, self).__init__()
2021
self.exported_program = exported_program
2122

23+
@property
24+
@abstractmethod
25+
def _passes_required_after(self) -> Set[Type[ExportPass]]:
26+
"""The subclass defines passes that must run after it"""
27+
pass
28+
29+
@staticmethod
30+
def get_required_passes(pass_) -> List[str]:
31+
"""
32+
Returns the list of passes that must be run after this pass, sorted by name.
33+
"""
34+
if hasattr(pass_, "_passes_required_after"):
35+
return sorted([ArmPass.get_name(p) for p in pass_._passes_required_after])
36+
else:
37+
return []
38+
39+
@staticmethod
40+
def get_name(pass_) -> str:
41+
"""
42+
Returns the name of the pass.
43+
"""
44+
if isinstance(pass_, ExportPass):
45+
return pass_.__class__.__name__
46+
elif hasattr(pass_, "__name__"):
47+
return pass_.__name__
48+
else:
49+
raise ValueError(
50+
f"Cannot get name for pass: {pass_}. It must be an instance of ExportPass or have a __name__ attribute."
51+
)
52+
2253
def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
2354
if not updated:
2455
return super().call_operator(op, args, kwargs, meta)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
# pyre-unsafe
99

10+
11+
from collections import defaultdict
12+
1013
import executorch.backends.arm.tosa.dialect # noqa: unused
1114
from executorch.backends.arm._passes import (
1215
AddBiasPass,
@@ -94,6 +97,7 @@
9497
UnsqueezeScalarPlaceholdersPass,
9598
)
9699

100+
from executorch.backends.arm._passes.arm_pass import ArmPass
97101
from executorch.backends.arm.tosa.specification import (
98102
TosaLoweringContext,
99103
TosaSpecification,
@@ -115,6 +119,32 @@ def __init__(self, tosa_spec: TosaSpecification) -> None:
115119
self.tosa_spec = tosa_spec
116120
super().__init__()
117121

122+
def validate_constraints_mandatory(self):
123+
"""
124+
Validates that necessary passes have run before transforming to backend.
125+
126+
Note that this differs from the original validate_constraints function, which
127+
only checks the order of passes.
128+
"""
129+
passes_to_run = defaultdict(list)
130+
131+
for current_pass in self.passes:
132+
current_pass_name = ArmPass.get_name(current_pass)
133+
for required_pass_name in ArmPass.get_required_passes(current_pass):
134+
passes_to_run[required_pass_name].append(current_pass_name)
135+
136+
passes_to_run.pop(current_pass_name, None)
137+
138+
if len(passes_to_run) > 0:
139+
error_msg = "The following constraints for passes are not met:\n"
140+
for required_pass, requiring_passes in passes_to_run.items():
141+
for requiring_pass in requiring_passes:
142+
error_msg += (
143+
f" - {required_pass} must run after {requiring_pass}\n"
144+
)
145+
146+
raise RuntimeError(error_msg)
147+
118148
def _transform(self, graph_module: GraphModule):
119149
with TosaLoweringContext(self.tosa_spec):
120150
return self(graph_module).graph_module
@@ -125,7 +155,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
125155
self.add_pass(RemoveGetItemPass())
126156
self.add_pass(ConvertSplitToSlicePass())
127157
self.add_pass(ConvertMmToBmmPass())
128-
self.add_pass(DecomposeLinearVectorNormPass())
129158
self.add_pass(
130159
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
131160
)
@@ -175,6 +204,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
175204
self.add_pass(RemoveNoopPass())
176205
self.add_pass(InsertRescalePass())
177206

207+
self.validate_constraints_mandatory()
178208
return self._transform(exported_program.graph_module)
179209

180210
def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
@@ -258,6 +288,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
258288
self.add_pass(RemoveNoopPass())
259289
self.add_pass(InsertRescalePass())
260290

291+
self.validate_constraints_mandatory()
261292
return self._transform(exported_program.graph_module)
262293

263294
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):

backends/arm/_passes/broadcast_args_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from typing import Set, Type
7+
68
from executorch.backends.arm._passes import ArmPass
79

810
from executorch.backends.arm._passes.arm_pass_utils import (
@@ -12,7 +14,7 @@
1214

1315
from executorch.exir.dialects._ops import ops as exir_ops
1416

15-
from executorch.exir.pass_base import PassResult
17+
from executorch.exir.pass_base import ExportPass, PassResult
1618
from torch.fx import GraphModule, Node
1719

1820

@@ -22,6 +24,8 @@ class BroadcastArgsPass(ArmPass):
2224
This is done when more than one arg needs broadcasting.
2325
"""
2426

27+
_passes_required_after: Set[Type[ExportPass]] = set()
28+
2529
targeted_ops = {
2630
exir_ops.edge.aten.add.Tensor,
2731
exir_ops.edge.aten.sub.Tensor,

0 commit comments

Comments
 (0)