Skip to content

Commit 1f5ab1c

Browse files
hsharma35facebook-github-bot
authored andcommitted
Add pass to convert kwargs to args + populate optional args. (#10857)
Summary: Adds a pass to convert kwargs to args and populate optional args. Reviewed By: zonglinpeng Differential Revision: D74510388
1 parent f39e694 commit 1f5ab1c

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

backends/cadence/aot/simplify_ops.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
CadencePassAttribute,
1717
register_cadence_pass,
1818
)
19-
2019
from executorch.exir.dialects._ops import ops as exir_ops
20+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2121
from executorch.exir.pass_base import ExportPass, ProxyValue
22+
from torch.fx.operator_schemas import get_signature_for_torch_op
2223

2324

2425
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -109,8 +110,43 @@ def call_operator(self, op, args, kwargs, meta):
109110
return super().call_operator(op, new_args, kwargs, meta)
110111

111112

113+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
114+
class BindOptionalArgsPass(ExportPass):
115+
"""Bind all optional args and kwargs."""
116+
117+
def call_operator(self, op, args, kwargs, meta):
118+
if not isinstance(op, EdgeOpOverload):
119+
return super().call_operator(op, args, kwargs, meta)
120+
assert callable(op)
121+
122+
torch_op_schemas = get_signature_for_torch_op(op._op)
123+
if len(torch_op_schemas) == 0:
124+
return super().call_operator(op, args, kwargs, meta)
125+
126+
matched_schemas = []
127+
# Iterate through all of the schema until we find one that matches
128+
# If one matches, populate `new_args_and_kwargs` with the new args/kwargs
129+
# values. If none matches, `new_args_and_kwargs` will be None
130+
for candidate_signature in torch_op_schemas:
131+
try:
132+
candidate_signature.bind(*args, **kwargs)
133+
matched_schemas.append(candidate_signature)
134+
except TypeError:
135+
continue
136+
137+
if len(matched_schemas) != 1:
138+
# Did not match any schema. Cannot normalize
139+
return super().call_operator(op, args, kwargs, meta)
140+
141+
sig = matched_schemas[0]
142+
bound_args = sig.bind(*args, **kwargs)
143+
144+
return super().call_operator(op, bound_args.args, bound_args.kwargs, meta)
145+
146+
112147
# This class encapsulates all the functions that simplify the op's args
113148
class CadenceSimplifyOpsInGraph:
114149
passes = [
115150
SimplifySliceOpPass,
151+
BindOptionalArgsPass,
116152
]

0 commit comments

Comments
 (0)