Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ python_unittest(
"fbsource//third-party/pypi/parameterized:parameterized",
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:simplify_ops",
Expand Down
5 changes: 5 additions & 0 deletions backends/cadence/aot/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ def call_submodule(
) -> PassResult:
return ExportPass().call(graph_module)

def call_getitem(
self, value: ProxyValue, key: int, meta: Optional[NodeMetadata] = None
) -> ProxyValue:
return super().call_getitem(value, key, meta or NodeMetadata({}))

def _fx(
self,
kind: str,
Expand Down
39 changes: 38 additions & 1 deletion backends/cadence/aot/simplify_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
CadencePassAttribute,
register_cadence_pass,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, ProxyValue
from torch.fx.operator_schemas import get_signature_for_torch_op


@register_cadence_pass(CadencePassAttribute(opt_level=0))
Expand Down Expand Up @@ -109,8 +110,44 @@ def call_operator(self, op, args, kwargs, meta):
return super().call_operator(op, new_args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class BindOptionalArgsPass(ExportPass):
"""Bind all optional args and kwargs."""

def call_operator(self, op, args, kwargs, meta):
if not isinstance(op, EdgeOpOverload):
return super().call_operator(op, args, kwargs, meta)
assert callable(op)

torch_op_schemas = get_signature_for_torch_op(op._op)
if len(torch_op_schemas) == 0:
return super().call_operator(op, args, kwargs, meta)

matched_schemas = []
# Iterate through all of the schema until we find one that matches
# If one matches, populate `new_args_and_kwargs` with the new args/kwargs
# values. If none matches, `new_args_and_kwargs` will be None
for candidate_signature in torch_op_schemas:
try:
candidate_signature.bind(*args, **kwargs)
matched_schemas.append(candidate_signature)
except TypeError:
continue

if len(matched_schemas) != 1:
# Did not match any schema. Cannot normalize
return super().call_operator(op, args, kwargs, meta)

sig = matched_schemas[0]
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()

return super().call_operator(op, bound_args.args, bound_args.kwargs, meta)


# This class encapsulates all the functions that simplify the op's args
class CadenceSimplifyOpsInGraph:
passes = [
SimplifySliceOpPass,
BindOptionalArgsPass,
]
36 changes: 35 additions & 1 deletion backends/cadence/aot/tests/test_simplify_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
from executorch.backends.cadence.aot.compiler import export_to_edge
from executorch.backends.cadence.aot.graph_builder import single_op_builder
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
from executorch.backends.cadence.aot.simplify_ops import (
BindOptionalArgsPass,
SimplifySliceOpPass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from parameterized.parameterized import parameterized
from torch.fx.passes.infra.pass_base import PassResult
Expand Down Expand Up @@ -112,3 +116,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1
)

def test_simplify_slice_op_args(self) -> None:
x = torch.rand(4, 5)
gm = single_op_builder(
placeholders=(x,),
op=exir_ops.edge.aten.slice_copy.Tensor,
args=(x, 1),
kwargs={"end": 3},
)
self.assertEqual(
[
(n.args[1:], n.kwargs)
for n in gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
],
[((1,), {"end": 3})],
)

gm = BindOptionalArgsPass().call(gm).graph_module

self.assertEqual(
[
(n.args[1:], n.kwargs)
for n in gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
],
[((1, None, 3, 1), {})],
)
Loading