|
16 | 16 | CadencePassAttribute, |
17 | 17 | register_cadence_pass, |
18 | 18 | ) |
19 | | - |
20 | 19 | from executorch.exir.dialects._ops import ops as exir_ops |
| 20 | +from executorch.exir.dialects.edge._ops import EdgeOpOverload |
21 | 21 | from executorch.exir.pass_base import ExportPass, ProxyValue |
| 22 | +from torch.fx.operator_schemas import get_signature_for_torch_op |
22 | 23 |
|
23 | 24 |
|
24 | 25 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
@@ -109,8 +110,43 @@ def call_operator(self, op, args, kwargs, meta): |
109 | 110 | return super().call_operator(op, new_args, kwargs, meta) |
110 | 111 |
|
111 | 112 |
|
| 113 | +@register_cadence_pass(CadencePassAttribute(opt_level=0)) |
| 114 | +class BindOptionalArgsPass(ExportPass): |
| 115 | + """Bind all optional args and kwargs.""" |
| 116 | + |
| 117 | + def call_operator(self, op, args, kwargs, meta): |
| 118 | + if not isinstance(op, EdgeOpOverload): |
| 119 | + return super().call_operator(op, args, kwargs, meta) |
| 120 | + assert callable(op) |
| 121 | + |
| 122 | + torch_op_schemas = get_signature_for_torch_op(op._op) |
| 123 | + if len(torch_op_schemas) == 0: |
| 124 | + return super().call_operator(op, args, kwargs, meta) |
| 125 | + |
| 126 | + matched_schemas = [] |
| 127 | + # Iterate through all of the schema until we find one that matches |
| 128 | + # If one matches, populate `new_args_and_kwargs` with the new args/kwargs |
| 129 | + # values. If none matches, `new_args_and_kwargs` will be None |
| 130 | + for candidate_signature in torch_op_schemas: |
| 131 | + try: |
| 132 | + candidate_signature.bind(*args, **kwargs) |
| 133 | + matched_schemas.append(candidate_signature) |
| 134 | + except TypeError: |
| 135 | + continue |
| 136 | + |
| 137 | + if len(matched_schemas) != 1: |
| 138 | + # Did not match any schema. Cannot normalize |
| 139 | + return super().call_operator(op, args, kwargs, meta) |
| 140 | + |
| 141 | + sig = matched_schemas[0] |
| 142 | + bound_args = sig.bind(*args, **kwargs) |
| 143 | + |
| 144 | + return super().call_operator(op, bound_args.args, bound_args.kwargs, meta) |
| 145 | + |
| 146 | + |
112 | 147 | # This class encapsulates all the functions that simplify the op's args |
113 | 148 | class CadenceSimplifyOpsInGraph: |
114 | 149 | passes = [ |
115 | 150 | SimplifySliceOpPass, |
| 151 | + BindOptionalArgsPass, |
116 | 152 | ] |
0 commit comments