
Commit bfd2634

Merge branch 'main' into jz/add-thinking-toggle
2 parents 0de4f59 + 087fe59

78 files changed: +1519 additions, -539 deletions

Note: large commits have some content hidden by default, so only a subset of the changed files appears below.

.ci/scripts/build-qnn-sdk.sh

Lines changed: 1 addition & 0 deletions

```diff
@@ -33,6 +33,7 @@ set_up_aot() {
   cmake .. \
     -DCMAKE_INSTALL_PREFIX=$PWD \
     -DEXECUTORCH_BUILD_QNN=ON \
+    -DANDROID_NATIVE_API_LEVEL=30 \
    -DQNN_SDK_ROOT=${QNN_SDK_ROOT} \
    -DEXECUTORCH_BUILD_DEVTOOLS=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
```

Package.swift

Lines changed: 1 addition & 3 deletions

```diff
@@ -96,10 +96,8 @@ let package = Package(
         .copy("resources/add.pte")
       ],
       linkerSettings: [
-        .linkedLibrary("c++"),
         .unsafeFlags([
-          "-Xlinker", "-force_load",
-          "-Xlinker", "cmake-out/kernels_portable.xcframework/macos-arm64/libkernels_portable_macos.a",
+          "-Xlinker", "-all_load",
         ])
       ]
     )
```
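
For context: `-force_load` tells the Apple linker to pull in every object file from the single static archive named after it, while `-all_load` does so for every static archive on the link line, which removes the need to reference `libkernels_portable_macos.a` by its hard-coded path. The explicit `.linkedLibrary("c++")` entry is dropped in the same change.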

backends/apple/mps/setup.md

Lines changed: 3 additions & 3 deletions

````diff
@@ -76,12 +76,12 @@ cd executorch
 ## Run the mv3 generated model using the mps_executor_runner
 
 ```bash
-./cmake-out/examples/apple/mps/mps_executor_runner --model_path mv3_mps_bundled_fp16.pte --bundled_program
+./cmake-out/examples/apple/mps/mps_executor_runner --model_path mv3_mps_float16_bundled.pte --bundled_program
 ```
 
 - You should see the following results. Note that no output file will be generated in this example:
 ```
-I 00:00:00.003290 executorch:mps_executor_runner.mm:286] Model file mv3_mps_bundled_fp16.pte is loaded.
+I 00:00:00.003290 executorch:mps_executor_runner.mm:286] Model file mv3_mps_float16_bundled.pte is loaded.
 I 00:00:00.003306 executorch:mps_executor_runner.mm:292] Program methods: 1
 I 00:00:00.003308 executorch:mps_executor_runner.mm:294] Running method forward
 I 00:00:00.003311 executorch:mps_executor_runner.mm:349] Setting up non-const buffer 1, size 606112.
@@ -118,7 +118,7 @@ python3 -m examples.apple.mps.scripts.mps_example --model_name="mv3" --generate_
 ```
 2. Run your Program on the ExecuTorch runtime and generate an [ETDump](../../../docs/source/etdump.md).
 ```
-./cmake-out/examples/apple/mps/mps_executor_runner --model_path mv3_mps_bundled_fp16.pte --bundled_program --dump-outputs
+./cmake-out/examples/apple/mps/mps_executor_runner --model_path mv3_mps_float16_bundled.pte --bundled_program --dump-outputs
 ```
 3. Create an instance of the Inspector API by passing in the ETDump you have sourced from the runtime along with the optionally generated ETRecord from step 1.
 ```bash
````

backends/cadence/aot/compiler_utils.py

Lines changed: 12 additions & 18 deletions

```diff
@@ -109,12 +109,12 @@ def get_cascaded_ops(
     return nodes
 
 
-# Capture the effect of transpose op on incoming dimension order
-def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
+def get_transposed_dims(
+    node: torch.fx.Node, dims: Optional[List[int]] = None
+) -> List[int]:
     """
-    Given a transpose node, and the incoming dimension ordering of the input
-    tensor to the transpose node, return the net effect of transpose op on the
-    dimension order.
+    Applies the transposition given by node to the dimension order in dims,
+    e.g. (1, 2) on [a, b, c, d] would return [a, c, b, d].
     """
     assert node.target == exir_ops.edge.aten.transpose_copy.int
     # Assert that the dims is not empty
@@ -127,28 +127,22 @@ def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
     assert isinstance(transpose_dims1, int)
     dim0 = transpose_dims0 if transpose_dims0 >= 0 else transpose_dims0 + dim_len
     dim1 = transpose_dims1 if transpose_dims1 >= 0 else transpose_dims1 + dim_len
-    # Perform transpose on dimmension ordering (dims)
-    dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
-    return dims
+    new_dims = list(dims)
+    new_dims[dim0], new_dims[dim1] = dims[dim1], dims[dim0]
+    return new_dims
 
 
-# Capture the effect of permute op on incoming dimension order
-def get_permuted_dims(node: torch.fx.Node, dims: Optional[Sequence[int]]) -> List[int]:
+def get_permuted_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
     """
-    Given a permute node, and the incoming dimension ordering of the input
-    tensor to the permute node, return the net effect of permute op on the
-    dimension order.
+    Applies the permutation given by node to the dimension order in dims,
+    e.g. (1, 2, 0) on [a, b, c] would return [b, c, a].
     """
     assert node.target == exir_ops.edge.aten.permute_copy.default
     # Permute each index of the dimension ordering (dims)
     # pyre-fixme[6]: This combined typecheck isn't supported yet.
     permute_dims: List[int] = list(node.args[1])
     assert all(isinstance(x, int) for x in permute_dims)
-    # If the dims is empty, we can simply return the permute order
-    if not dims:
-        return permute_dims
-    dims = [dims[x] for x in permute_dims]
-    return dims
+    return [dims[x] for x in permute_dims]
 
 
 # Return the tensor of buffer/parameter op
```
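
The two docstring examples above can be checked in a few lines of plain Python. The following is a minimal sketch of the helpers' reordering logic without the `torch.fx` node plumbing; the function and variable names here are illustrative, not from the codebase:

```python
# Standalone sketch of the dimension-order bookkeeping in get_transposed_dims
# and get_permuted_dims; plain lists stand in for torch.fx node metadata.
from typing import List, Tuple


def apply_transpose(dims: List[str], swap: Tuple[int, int]) -> List[str]:
    # Swap two entries of the dimension order without mutating the input,
    # matching the new copy-then-swap behavior of get_transposed_dims.
    new_dims = list(dims)
    d0, d1 = swap
    new_dims[d0], new_dims[d1] = dims[d1], dims[d0]
    return new_dims


def apply_permute(dims: List[str], order: List[int]) -> List[str]:
    # Reorder the dimension order by a permutation, as get_permuted_dims does.
    return [dims[i] for i in order]


# The docstring examples from the diff:
assert apply_transpose(["a", "b", "c", "d"], (1, 2)) == ["a", "c", "b", "d"]
assert apply_permute(["a", "b", "c"], [1, 2, 0]) == ["b", "c", "a"]
```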

backends/cadence/aot/fuse_ops.py

Lines changed: 25 additions & 43 deletions

```diff
@@ -14,7 +14,7 @@
 import operator
 from collections import deque
 from numbers import Number
-from typing import cast, Sequence
+from typing import Any, Callable, cast
 
 # Import these for the cadence function signatures.
 import executorch.backends.cadence.aot.ops_registrations  # noqa: F401
@@ -881,9 +881,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
+class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
     """
-    Fuse transpose op pairs to a single view op.
+    Fuse transpose or permute op pairs to a single view op:
+    (transpose or permute) -> (quant or dequant) -> (transpose or permute)
     """
 
     # A list of ops that can be bypassed when looking for a
@@ -907,42 +908,17 @@ def can_fuse_for_chain(
         if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
             return False
 
-        def get_dims(node: torch.fx.Node) -> tuple[int, int]:
-            def canonicalize(dim: int) -> int:
-                if dim < 0:
-                    dim += len(node.meta["val"].shape)
-                return dim
-
-            return tuple(canonicalize(cast(int, d)) for d in node.args[1:3])
-
-        def is_equivalent(
-            shape: Sequence[int],
-            transpose0: tuple[int, int],
-            transpose1: tuple[int, int],
-        ) -> bool:
-            def permute_order(
-                order: Sequence[int], dims: tuple[int, int]
-            ) -> Sequence[int]:
-                new_order = list(order)
-                new_order[dims[0]], new_order[dims[1]] = (
-                    new_order[dims[1]],
-                    new_order[dims[0]],
-                )
-                return new_order
-
-            order = permute_order(range(len(shape)), transpose0)
-            order = permute_order(order, transpose1)
-
-            non_unit_dims = [dim for dim in range(len(shape)) if shape[dim] != 1]
-            non_unit_dims_permuted = [dim for dim in order if shape[dim] != 1]
-
-            return non_unit_dims == non_unit_dims_permuted
-
-        return is_equivalent(
-            cast(torch.fx.Node, producer.args[0]).meta["val"].shape,
-            get_dims(producer),
-            get_dims(consumer),
-        )
+        # Check that permute2(permute1(identity)) == identity
+        input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
+        ident_dims = list(range(len(input_shape)))
+        # This mapping lets the same code handle both transposes and permutes
+        f: dict[Any, Callable] = {
+            exir_ops.edge.aten.transpose_copy.int: get_transposed_dims,
+            exir_ops.edge.aten.permute_copy.default: get_permuted_dims,
+        }
+        in_dims = f[producer.target](producer, ident_dims)
+        out_dims = f[consumer.target](consumer, in_dims)
+        return out_dims == ident_dims
 
     def get_fused_node(
         self,
@@ -960,11 +936,17 @@ def get_fused_node(
         return view
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        # Remove any dequantize op that has only quantize ops as its users.
+        # Remove any transpose/permute op pairs that cancel each other.
         self.find_and_fuse(
             graph_module,
-            producer_op_packets={exir_ops.edge.aten.transpose_copy},
-            consumer_op_packets={exir_ops.edge.aten.transpose_copy},
+            producer_op_packets={
+                exir_ops.edge.aten.transpose_copy,
+                exir_ops.edge.aten.permute_copy,
+            },
+            consumer_op_packets={
+                exir_ops.edge.aten.transpose_copy,
+                exir_ops.edge.aten.permute_copy,
+            },
             bypass_ops=self.bypass_ops,
         )
         result = super().call(graph_module)
@@ -1028,5 +1010,5 @@ class CadenceFuseOpsInGraph:
     FuseQuantDequantToRequantizePass,
     FuseMulIntoDequantPass,
     FuseFullThenReshapePass,
-    FuseTransposeOpPairsPass,
+    FuseTransposeOrPermuteOpPairsPass,
 ]
```
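
The rewritten `can_fuse_for_chain` reduces to a simple invariant: push the identity dimension order through the producer's reordering and then the consumer's, and fuse (into a single view op) only if the identity comes back out. A minimal sketch of that cancellation test, with plain permutation lists standing in for the `torch.fx` transpose/permute nodes (names here are illustrative):

```python
# Sketch of the identity-cancellation test used by can_fuse_for_chain.
from typing import List


def permuted(dims: List[int], order: List[int]) -> List[int]:
    # Same reordering rule as get_permuted_dims; a transpose is just the
    # special case where `order` swaps two indices.
    return [dims[i] for i in order]


def cancels_to_identity(producer_order: List[int], consumer_order: List[int]) -> bool:
    ident = list(range(len(producer_order)))
    return permuted(permuted(ident, producer_order), consumer_order) == ident


# permute (1, 2, 0) followed by its inverse (2, 0, 1) is a no-op -> fusable
assert cancels_to_identity([1, 2, 0], [2, 0, 1])
# two identical transposes of dims 0 and 1 also cancel
assert cancels_to_identity([1, 0, 2], [1, 0, 2])
# permute (1, 2, 0) applied twice is not the identity -> not fusable
assert not cancels_to_identity([1, 2, 0], [1, 2, 0])
```

Note that this is stricter than the code it replaces, which also treated reorderings as equivalent when they only moved unit-sized dimensions.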

backends/cadence/aot/passes.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -14,7 +14,7 @@
 from executorch.backends.cadence.aot.fuse_ops import (
     CadenceFuseOpsInGraph,
     FuseFullThenReshapePass,
-    FuseTransposeOpPairsPass,
+    FuseTransposeOrPermuteOpPairsPass,
 )
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
@@ -83,7 +83,7 @@ def get_passes_in_default_order() -> List[ExportPass]:
         CadenceSimplifyOpsInGraph.passes,
         FinalizePipeline,
         FuseFullThenReshapePass,
-        FuseTransposeOpPairsPass,
+        FuseTransposeOrPermuteOpPairsPass,
         RemoveNopSliceOrViewOpPass,
     ]
     return pytree.tree_flatten(passes)[0]
```

backends/cadence/aot/replace_ops.py

Lines changed: 23 additions & 10 deletions

```diff
@@ -2263,9 +2263,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class ReplacePowWithMullPass(ExportPass):
+class ReplacePowWithMulPass(ExportPass):
     """
-    Replace the pow op with degree 2 for a mul op.
+    Replace the pow op with a chain of mul ops.
     """
 
     def call_operator(
@@ -2275,19 +2275,32 @@ def call_operator(
         kwargs: Dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        # TODO(eigen): Add support for other degrees.
-        if (
-            op
-            not in {
-                exir_ops.edge.aten.pow.Scalar,
+        if not (
+            len(args) > 1
+            and isinstance(args[1], int)
+            and cast(int, args[1]) > 1
+            and cast(int, args[1]) < 5
+            and op
+            in {
+                exir_ops.edge.aten.pow.Tensor_Scalar,
             }
-            or args[0] != 2
         ):
             return super().call_operator(op, args, kwargs, meta)
 
+        x = args[0]
+        exponent = cast(int, args[1])
+
+        if exponent > 2:
+            for _ in range(exponent, 2, -1):
+                x = super().call_operator(
+                    exir_ops.edge.aten.mul.Tensor,
+                    (x, args[0]),
+                    {},
+                    meta,
+                )
         return super().call_operator(
             exir_ops.edge.aten.mul.Tensor,
-            (args[1], args[1]),
+            (x, args[0]),
             {},
             meta,
         )
@@ -2429,5 +2442,5 @@ class CadenceReplaceOpsInGraph:
     ReplaceWhereWithFullArgsWithWhereScalar,
     ReplaceGeluWithApproximateGeluPass,
     ReplaceSplitWithSlicePass,
-    ReplacePowWithMullPass,
+    ReplacePowWithMulPass,
 ]
```
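
The arithmetic behind the rewritten pass is easy to verify in isolation: `x ** n` for an integer exponent `n` in [2, 4] becomes a chain of `n - 1` multiplies. A sketch of the same loop on plain numbers, with the graph-building `call_operator` plumbing omitted (the function name is illustrative; the loop bounds are copied from the pass):

```python
# Sketch of the exponent-to-mul-chain arithmetic in ReplacePowWithMulPass.
def pow_as_muls(base: float, exponent: int) -> float:
    assert 1 < exponent < 5, "the pass only matches integer exponents 2..4"
    x = base
    # Same loop shape as the pass: one mul per unit of exponent above 2 ...
    if exponent > 2:
        for _ in range(exponent, 2, -1):
            x = x * base
    # ... plus the final mul that every matched pow gets.
    return x * base


assert pow_as_muls(3.0, 2) == 3.0 ** 2
assert pow_as_muls(3.0, 3) == 3.0 ** 3
assert pow_as_muls(3.0, 4) == 3.0 ** 4
```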
