Skip to content

Commit 792c964

Browse files
committed
Merge remote-tracking branch 'origin/main' into android-backend-used-by-method
2 parents a833d98 + d069d65 commit 792c964

File tree

29 files changed

+199
-131
lines changed

29 files changed

+199
-131
lines changed

CMakeLists.txt

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ project(executorch)
4848
# MARK: - Start EXECUTORCH_H12025_BUILD_MIGRATION --------------------------------------------------
4949

5050
include(${PROJECT_SOURCE_DIR}/tools/cmake/common/preset.cmake)
51+
include(${PROJECT_SOURCE_DIR}/tools/cmake/Utils.cmake)
52+
include(CMakeDependentOption)
53+
include(ExternalProject)
5154

5255
if(NOT CMAKE_CXX_STANDARD)
5356
set(CMAKE_CXX_STANDARD 17)
@@ -64,10 +67,14 @@ if(NOT CMAKE_BUILD_TYPE)
6467
endif()
6568
announce_configured_options(CMAKE_BUILD_TYPE)
6669

70+
if(NOT PYTHON_EXECUTABLE)
71+
resolve_python_executable()
72+
endif()
73+
announce_configured_options(PYTHON_EXECUTABLE)
74+
6775
announce_configured_options(CMAKE_CXX_COMPILER_ID)
6876
announce_configured_options(CMAKE_TOOLCHAIN_FILE)
6977
announce_configured_options(BUCK2)
70-
announce_configured_options(PYTHON_EXECUTABLE)
7178

7279
load_build_preset()
7380
include(${PROJECT_SOURCE_DIR}/tools/cmake/preset/default.cmake)
@@ -77,10 +84,6 @@ print_configured_options()
7784

7885
# MARK: - End EXECUTORCH_H12025_BUILD_MIGRATION ----------------------------------------------------
7986

80-
include(tools/cmake/Utils.cmake)
81-
include(CMakeDependentOption)
82-
include(ExternalProject)
83-
8487
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
8588

8689
# Setup RPATH.
@@ -256,11 +259,6 @@ if(EXECUTORCH_BUILD_TESTS)
256259
include(CTest)
257260
endif()
258261

259-
if(NOT PYTHON_EXECUTABLE)
260-
resolve_python_executable()
261-
endif()
262-
message(STATUS "Using python executable '${PYTHON_EXECUTABLE}'")
263-
264262
# TODO(dbort): Fix these warnings and remove this flag.
265263
set(_common_compile_options -Wno-deprecated-declarations -fPIC)
266264

backends/apple/mps/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ endif()
1818

1919
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
2020

21-
if(NOT PYTHON_EXECUTABLE)
22-
resolve_python_executable()
23-
endif()
24-
2521
set(_common_compile_options -Wno-deprecated-declarations)
2622
set(_common_include_directories ${EXECUTORCH_ROOT}/..)
2723

backends/cadence/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,6 @@ add_compile_definitions(C10_USING_CUSTOM_GENERATED_MACROS)
3030
if(EXECUTORCH_CADENCE_CPU_RUNNER)
3131
include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
3232

33-
if(NOT PYTHON_EXECUTABLE)
34-
resolve_python_executable()
35-
endif()
36-
3733
set(_common_compile_options -Wno-deprecated-declarations -fPIC)
3834

3935
# Find prebuilt libraries. executorch package should contain portable_ops_lib,

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ python_unittest(
367367
"fbsource//third-party/pypi/parameterized:parameterized",
368368
"//caffe2:torch",
369369
"//executorch/backends/cadence/aot:compiler",
370+
"//executorch/backends/cadence/aot:graph_builder",
370371
"//executorch/backends/cadence/aot:ops_registrations",
371372
"//executorch/backends/cadence/aot:pass_utils",
372373
"//executorch/backends/cadence/aot:simplify_ops",

backends/cadence/aot/graph_builder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ def call_submodule(
9696
) -> PassResult:
9797
return ExportPass().call(graph_module)
9898

99+
def call_getitem(
100+
self, value: ProxyValue, key: int, meta: Optional[NodeMetadata] = None
101+
) -> ProxyValue:
102+
return super().call_getitem(value, key, meta or NodeMetadata({}))
103+
99104
def _fx(
100105
self,
101106
kind: str,

backends/cadence/aot/pass_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,34 @@ def nodes_not_adjacent_in_gm(
157157
if node.next.target == succ_target:
158158
return False
159159
return True
160+
161+
162+
def get_arg(
163+
node: torch.fx.Node,
164+
arg_index: int,
165+
kwarg_name: str,
166+
*,
167+
default: torch.fx.node.Argument = None,
168+
) -> torch.fx.node.Argument:
169+
"""
170+
Get the arg at arg_index or the kwarg named kwarg_name of the node. If neither is found
171+
return default.
172+
"""
173+
if arg_index < len(node.args):
174+
return node.args[arg_index]
175+
elif kwarg_name in node.kwargs:
176+
return node.kwargs[kwarg_name]
177+
else:
178+
return default
179+
180+
181+
def set_arg(
182+
node: torch.fx.Node, arg_index: int, kwarg_name: str, value: torch.fx.node.Argument
183+
) -> None:
184+
"""
185+
Set the arg at arg_index if it exists, otherwise set the kwarg.
186+
"""
187+
if arg_index < len(node.args):
188+
node.update_arg(arg_index, value)
189+
else:
190+
node.update_kwarg(kwarg_name, value)

backends/cadence/aot/remove_ops.py

Lines changed: 43 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
import torch.fx
2626
from executorch.backends.cadence.aot.pass_utils import (
2727
CadencePassAttribute,
28+
get_arg,
2829
register_cadence_pass,
30+
set_arg,
2931
)
3032

3133
from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
@@ -37,7 +39,7 @@
3739
from executorch.exir.pass_manager import PassManager, PassType
3840
from executorch.exir.passes import dead_code_elimination_pass
3941
from executorch.exir.passes.spec_prop_pass import SpecPropPass
40-
from torch.fx.node import Argument
42+
from torch.fx.node import Argument, Node
4143

4244

4345
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -771,65 +773,52 @@ def remove_branched(
771773

772774

773775
class RemoveCatFromSliceCopyPass(ExportPass):
774-
def _remove_unused_cat( # noqa: C901
775-
self, graph_module: torch.fx.GraphModule
776-
) -> None:
777-
slice_copy_nodes = [
778-
node
779-
for node in graph_module.graph.nodes
780-
if node.target == exir_ops.edge.aten.slice_copy.Tensor
781-
]
782-
for slice_copy_node in slice_copy_nodes:
783-
slice_dim, start_idx, end_idx, step = 0, 0, float("inf"), 1
784-
input_node, *other_args = slice_copy_node.args
785-
if len(other_args) >= 1:
786-
slice_dim = other_args[0]
787-
if len(other_args) >= 2:
788-
start_idx = other_args[1]
789-
if len(other_args) >= 3:
790-
end_idx = other_args[2]
791-
if len(other_args) >= 4:
792-
step = other_args[3]
793-
if step != 1:
794-
continue
795-
slice_copy_dtype = slice_copy_node.meta["val"].dtype
796-
if input_node.target != exir_ops.edge.aten.cat.default:
797-
continue
798-
cat_dtype = input_node.meta["val"].dtype
799-
if slice_copy_dtype != cat_dtype:
776+
"""
777+
Simplifies cat->slice_copy chains where one of the cat inputs can be directly passed
778+
to the slice_copy.
779+
"""
780+
781+
def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
782+
for slice_copy_node in graph_module.graph.find_nodes(
783+
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
784+
):
785+
cat_node = cast(Node, get_arg(slice_copy_node, 0, "input"))
786+
slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0))
787+
start_idx = cast(int, get_arg(slice_copy_node, 2, "start", default=None))
788+
end_idx = cast(int, get_arg(slice_copy_node, 3, "end", default=None))
789+
step = cast(int, get_arg(slice_copy_node, 4, "step", default=1))
790+
791+
if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
800792
continue
801-
cat_dim = input_node.args[1:]
802-
if len(cat_dim) == 0:
803-
cat_dim = 0
793+
794+
# Make sure cat and slice happen on the same dimension.
795+
cat_dim = cast(Node, get_arg(cat_node, 1, "dim", default=0))
804796
if cat_dim != slice_dim:
805797
continue
806-
cat_output_shape = input_node.meta["val"].shape
807-
start_idx = (
808-
cat_output_shape[cat_dim] + start_idx if start_idx < 0 else start_idx
809-
)
810-
end_idx = (
811-
cat_output_shape[cat_dim]
812-
if end_idx > cat_output_shape[cat_dim]
813-
else end_idx
814-
)
815-
base_idx = 0
816-
cat_input_to_keep = None
817-
for cat_input_node in input_node.args[0]:
818-
cat_input_dtype = cat_input_node.meta["val"].dtype
819-
if slice_copy_dtype != cat_input_dtype:
820-
continue
798+
799+
# Canonicalize slice indices.
800+
cat_output_shape = cat_node.meta["val"].shape
801+
if start_idx is None:
802+
start_idx = 0
803+
elif start_idx < 0:
804+
start_idx += cat_output_shape[cat_dim]
805+
if end_idx is None or end_idx > cat_output_shape[cat_dim]:
806+
end_idx = cat_output_shape[cat_dim]
807+
elif end_idx < 0:
808+
end_idx += cat_output_shape[cat_dim]
809+
810+
offset = 0
811+
for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")):
821812
cat_input_shape = cat_input_node.meta["val"].shape
822813

823-
# check if the slice range overlaps with the cat range
824-
if (
825-
base_idx <= start_idx
826-
and end_idx <= list(cat_input_shape)[cat_dim] + base_idx
827-
):
828-
cat_input_to_keep = cat_input_node
814+
# Check if the slice range lies entirely within this cat input's range.
815+
if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
816+
slice_copy_node.replace_input_with(cat_node, cat_input_node)
817+
set_arg(slice_copy_node, 2, "start", start_idx - offset)
818+
set_arg(slice_copy_node, 3, "end", end_idx - offset)
829819
break
830-
base_idx += list(cat_input_shape)[cat_dim]
831-
if cat_input_to_keep is not None:
832-
slice_copy_node.replace_input_with(input_node, cat_input_to_keep)
820+
821+
offset += cat_input_shape[cat_dim]
833822

834823
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
835824
self._remove_unused_cat(graph_module)

backends/cadence/aot/simplify_ops.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
CadencePassAttribute,
1717
register_cadence_pass,
1818
)
19-
2019
from executorch.exir.dialects._ops import ops as exir_ops
20+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2121
from executorch.exir.pass_base import ExportPass, ProxyValue
22+
from torch.fx.operator_schemas import get_signature_for_torch_op
2223

2324

2425
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -109,8 +110,44 @@ def call_operator(self, op, args, kwargs, meta):
109110
return super().call_operator(op, new_args, kwargs, meta)
110111

111112

113+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
114+
class BindOptionalArgsPass(ExportPass):
115+
"""Bind all optional args and kwargs."""
116+
117+
def call_operator(self, op, args, kwargs, meta):
118+
if not isinstance(op, EdgeOpOverload):
119+
return super().call_operator(op, args, kwargs, meta)
120+
assert callable(op)
121+
122+
torch_op_schemas = get_signature_for_torch_op(op._op)
123+
if len(torch_op_schemas) == 0:
124+
return super().call_operator(op, args, kwargs, meta)
125+
126+
matched_schemas = []
127+
# Try to bind the given args/kwargs against every candidate schema.
128+
# Signatures whose bind() raises TypeError are skipped; each one that binds
129+
# is appended to `matched_schemas`, which stays empty if none matches.
130+
for candidate_signature in torch_op_schemas:
131+
try:
132+
candidate_signature.bind(*args, **kwargs)
133+
matched_schemas.append(candidate_signature)
134+
except TypeError:
135+
continue
136+
137+
if len(matched_schemas) != 1:
138+
# Zero or multiple schemas matched, so the binding is ambiguous. Cannot normalize
139+
return super().call_operator(op, args, kwargs, meta)
140+
141+
sig = matched_schemas[0]
142+
bound_args = sig.bind(*args, **kwargs)
143+
bound_args.apply_defaults()
144+
145+
return super().call_operator(op, bound_args.args, bound_args.kwargs, meta)
146+
147+
112148
# This class encapsulates all the functions that simplify the op's args
113149
class CadenceSimplifyOpsInGraph:
114150
passes = [
115151
SimplifySliceOpPass,
152+
BindOptionalArgsPass,
116153
]

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,3 +864,30 @@ def forward(self, x, y):
864864

865865
# Ensure both cat nodes were removed
866866
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)
867+
868+
def test_remove_cat_from_slice_copy_second_input(self) -> None:
869+
builder = GraphBuilder()
870+
x = builder.placeholder("x", torch.randn(2, 4))
871+
y = builder.placeholder("y", torch.randn(2, 4))
872+
cat = builder.call_operator(
873+
op=exir_ops.edge.aten.cat.default,
874+
args=((x, y), 1),
875+
)
876+
slice_copy = builder.call_operator(
877+
op=exir_ops.edge.aten.slice_copy.Tensor,
878+
args=(cat, 1, 5, 7, 1),
879+
)
880+
builder.output([slice_copy])
881+
graph_module = builder.get_graph_module()
882+
883+
inputs = (torch.randn(2, 4), torch.randn(2, 4))
884+
expected_outputs = graph_module(*inputs)[0]
885+
886+
p = RemoveCatFromSliceCopyPass()
887+
graph_module = cast(PassResult, p(graph_module)).graph_module
888+
889+
# Cat should be removed.
890+
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)
891+
892+
# Output should remain the same.
893+
self.assertTrue(torch.equal(graph_module(*inputs)[0], expected_outputs))

backends/cadence/aot/tests/test_simplify_ops_passes.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@
1313
import executorch.backends.cadence.aot.ops_registrations # noqa
1414
import torch
1515
from executorch.backends.cadence.aot.compiler import export_to_edge
16+
from executorch.backends.cadence.aot.graph_builder import single_op_builder
1617
from executorch.backends.cadence.aot.pass_utils import count_node
17-
from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
18+
from executorch.backends.cadence.aot.simplify_ops import (
19+
BindOptionalArgsPass,
20+
SimplifySliceOpPass,
21+
)
1822
from executorch.exir.dialects._ops import ops as exir_ops
1923
from parameterized.parameterized import parameterized
2024
from torch.fx.passes.infra.pass_base import PassResult
@@ -112,3 +116,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
112116
self.assertEqual(
113117
count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1
114118
)
119+
120+
def test_simplify_slice_op_args(self) -> None:
121+
x = torch.rand(4, 5)
122+
gm = single_op_builder(
123+
placeholders=(x,),
124+
op=exir_ops.edge.aten.slice_copy.Tensor,
125+
args=(x, 1),
126+
kwargs={"end": 3},
127+
)
128+
self.assertEqual(
129+
[
130+
(n.args[1:], n.kwargs)
131+
for n in gm.graph.find_nodes(
132+
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
133+
)
134+
],
135+
[((1,), {"end": 3})],
136+
)
137+
138+
gm = BindOptionalArgsPass().call(gm).graph_module
139+
140+
self.assertEqual(
141+
[
142+
(n.args[1:], n.kwargs)
143+
for n in gm.graph.find_nodes(
144+
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
145+
)
146+
],
147+
[((1, None, 3, 1), {})],
148+
)

0 commit comments

Comments
 (0)