|
16 | 16 | CadencePassAttribute, |
17 | 17 | register_cadence_pass, |
18 | 18 | ) |
| 19 | +from executorch.backends.cadence.aot.utils import rebind |
19 | 20 | from executorch.exir.dialects._ops import ops as exir_ops |
20 | 21 | from executorch.exir.dialects.edge._ops import EdgeOpOverload |
21 | 22 | from executorch.exir.pass_base import ExportPass, ProxyValue |
22 | | -from torch.fx.operator_schemas import get_signature_for_torch_op |
23 | 23 |
|
24 | 24 |
|
25 | 25 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
@@ -117,32 +117,11 @@ class BindOptionalArgsPass(ExportPass): |
117 | 117 | def call_operator(self, op, args, kwargs, meta): |
118 | 118 | if not isinstance(op, EdgeOpOverload): |
119 | 119 | return super().call_operator(op, args, kwargs, meta) |
120 | | - assert callable(op) |
121 | 120 |
|
122 | | - torch_op_schemas = get_signature_for_torch_op(op._op) |
123 | | - if len(torch_op_schemas) == 0: |
124 | | - return super().call_operator(op, args, kwargs, meta) |
125 | | - |
126 | | - matched_schemas = [] |
127 | | - # Iterate through all of the schema until we find one that matches |
128 | | - # If one matches, populate `new_args_and_kwargs` with the new args/kwargs |
129 | | - # values. If none matches, `new_args_and_kwargs` will be None |
130 | | - for candidate_signature in torch_op_schemas: |
131 | | - try: |
132 | | - candidate_signature.bind(*args, **kwargs) |
133 | | - matched_schemas.append(candidate_signature) |
134 | | - except TypeError: |
135 | | - continue |
136 | | - |
137 | | - if len(matched_schemas) != 1: |
138 | | - # Did not match any schema. Cannot normalize |
139 | | - return super().call_operator(op, args, kwargs, meta) |
140 | | - |
141 | | - sig = matched_schemas[0] |
142 | | - bound_args = sig.bind(*args, **kwargs) |
143 | | - bound_args.apply_defaults() |
| 121 | + if (updated_args := rebind(op, args, kwargs)) is not None: |
| 122 | + args, kwargs = updated_args |
144 | 123 |
|
145 | | - return super().call_operator(op, bound_args.args, bound_args.kwargs, meta) |
| 124 | + return super().call_operator(op, args, kwargs, meta) |
146 | 125 |
|
147 | 126 |
|
148 | 127 | # This class encapsulates all the functions that simplify the op's args |
|
0 commit comments