|
16 | 16 | CadencePassAttribute,
|
17 | 17 | register_cadence_pass,
|
18 | 18 | )
|
19 |
| - |
20 | 19 | from executorch.exir.dialects._ops import ops as exir_ops
|
| 20 | +from executorch.exir.dialects.edge._ops import EdgeOpOverload |
21 | 21 | from executorch.exir.pass_base import ExportPass, ProxyValue
|
| 22 | +from torch.fx.operator_schemas import get_signature_for_torch_op |
22 | 23 |
|
23 | 24 |
|
24 | 25 | @register_cadence_pass(CadencePassAttribute(opt_level=0))
|
@@ -109,8 +110,44 @@ def call_operator(self, op, args, kwargs, meta):
|
109 | 110 | return super().call_operator(op, new_args, kwargs, meta)
|
110 | 111 |
|
111 | 112 |
|
| 113 | +@register_cadence_pass(CadencePassAttribute(opt_level=0)) |
| 114 | +class BindOptionalArgsPass(ExportPass): |
| 115 | + """Bind all optional args and kwargs.""" |
| 116 | + |
| 117 | + def call_operator(self, op, args, kwargs, meta): |
| 118 | + if not isinstance(op, EdgeOpOverload): |
| 119 | + return super().call_operator(op, args, kwargs, meta) |
| 120 | + assert callable(op) |
| 121 | + |
| 122 | + torch_op_schemas = get_signature_for_torch_op(op._op) |
| 123 | + if len(torch_op_schemas) == 0: |
| 124 | + return super().call_operator(op, args, kwargs, meta) |
| 125 | + |
| 126 | + matched_schemas = [] |
| 127 | + # Iterate through all of the schema until we find one that matches |
| 128 | + # If one matches, populate `new_args_and_kwargs` with the new args/kwargs |
| 129 | + # values. If none matches, `new_args_and_kwargs` will be None |
| 130 | + for candidate_signature in torch_op_schemas: |
| 131 | + try: |
| 132 | + candidate_signature.bind(*args, **kwargs) |
| 133 | + matched_schemas.append(candidate_signature) |
| 134 | + except TypeError: |
| 135 | + continue |
| 136 | + |
| 137 | + if len(matched_schemas) != 1: |
| 138 | + # Did not match any schema. Cannot normalize |
| 139 | + return super().call_operator(op, args, kwargs, meta) |
| 140 | + |
| 141 | + sig = matched_schemas[0] |
| 142 | + bound_args = sig.bind(*args, **kwargs) |
| 143 | + bound_args.apply_defaults() |
| 144 | + |
| 145 | + return super().call_operator(op, bound_args.args, bound_args.kwargs, meta) |
| 146 | + |
| 147 | + |
112 | 148 | # This class encapsulates all the functions that simplify the op's args
|
113 | 149 | class CadenceSimplifyOpsInGraph:
|
114 | 150 | passes = [
|
115 | 151 | SimplifySliceOpPass,
|
| 152 | + BindOptionalArgsPass, |
116 | 153 | ]
|
0 commit comments