|
15 | 15 | from executorch.exir.dialects.edge._ops import EdgeOpOverload |
16 | 16 | from executorch.exir.error import ExportError, ExportErrorType |
17 | 17 | from executorch.exir.lowered_backend_module import LoweredBackendModule |
| 18 | +from executorch.exir.pass_base import ExportPass |
| 19 | +from executorch.exir.pass_manager import PassManager |
18 | 20 | from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap |
19 | 21 | from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS |
20 | 22 | from executorch.exir.passes.replace_aten_with_edge_pass import DISALLOW_LIST |
|
35 | 37 |
|
36 | 38 | ALLOWED_META_KEYS = {"spec", "stack_trace"} |
37 | 39 |
|
| 40 | +class AmbiguousDimOrderError(RuntimeError):
| 41 | +    """Raised when a tensor's dim_order is consistent with more than one
| 42 | +    of the checked memory formats."""
| 43 | + |
| 44 | +def assert_unambiguous_dim_order(gm: GraphModule) -> None:
| 45 | + class ExampleNOPPass(ExportPass): |
| 46 | + """ |
| 47 | + Does nothing! |
| 48 | + """ |
| 49 | + |
| 50 | + def call_operator(self, op, args, kwargs, meta): |
| 51 | + return super().call_operator( |
| 52 | + op, |
| 53 | + args, |
| 54 | + kwargs, |
| 55 | + meta, |
| 56 | + ) |
| 57 | + |
| 58 | +    # This is an example of how to detect an ambiguous dim_order anywhere in the graph.
| 59 | +    # You can be more surgical and only check the nodes you are interested in.
| 60 | +    def detect_ambiguity(gm: GraphModule) -> None:
| 61 | +        """
| 62 | +        Check every node's output tensors and raise if any dim_order is ambiguous for the checked memory formats.
| 63 | +        """
| 64 | + |
| 65 | + def get_tensors(node: torch.fx.Node) -> List[torch.Tensor]: |
| 66 | + val = node.meta["val"] |
| 67 | + if isinstance(val, torch.Tensor): |
| 68 | + return [val] |
| 69 | + elif isinstance(val, (list, tuple)): |
| 70 | + return [tensor for tensor in val if isinstance(tensor, torch.Tensor)] |
| 71 | + return [] |
| 72 | + |
| 73 | + for node in gm.graph.nodes: |
| 74 | + if node.op == "call_function": |
| 75 | + for tensor in get_tensors(node): |
| 76 | +                    # Make sure the dim_order is not ambiguous; raise otherwise.
| 77 | +                    # We raise because the verifier cannot recover from this.
| 78 | +                    # The right follow-up is to ask the user to re-export with a different example input.
| 79 | + try: |
| 80 | + _ = tensor.dim_order( |
| 81 | + ambiguity_check=[ |
| 82 | + torch.contiguous_format, |
| 83 | + torch.channels_last, |
| 84 | + ] |
| 85 | + ) |
| 86 | +                    except Exception as e:
| 87 | +                        raise AmbiguousDimOrderError(f"ambiguous dim_order for node {node.name}") from e
| 88 | + |
| 89 | +    # Any pass or passes can be used here; ExampleNOPPass is just a placeholder.
| 90 | + dim_order_pass_manager = PassManager(passes=[ExampleNOPPass()]) |
| 91 | + dim_order_pass_manager.add_checks(detect_ambiguity) |
| 92 | + dim_order_pass_manager(gm) |
38 | 93 |
|
39 | 94 | def _check_tensors_are_contiguous(gm: GraphModule) -> None: |
40 | 95 |     # Tensors must be of contiguous format
@@ -281,7 +336,9 @@ def check_additional(self, gm: GraphModule) -> None: |
281 | 336 |         if self.check_edge_ops:
282 | 337 |             _check_tensors_are_contiguous(gm)
283 | 338 |             _check_tensor_args_matching_op_allowed_dtype(gm)
284 | | - |
| 339 | +        if self.use_dim_order:
| 340 | +            assert_unambiguous_dim_order(gm)
| 341 | + |
285 | 342 |     def is_valid(self, gm: GraphModule) -> bool:
286 | 343 |         try:
287 | 344 |             self(gm)
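For reviewers who want to see the new check fire, here is a minimal sketch (not part of this diff) of the `Tensor.dim_order(ambiguity_check=...)` behavior the added verifier code relies on. It assumes a PyTorch build where `Tensor.dim_order` accepts `ambiguity_check`, as the diff already requires; the tensor shapes and the commented-out export wiring are illustrative assumptions, not code from this PR.

```python
import torch

# Unambiguous: for shape (2, 3, 4, 5) the contiguous and channels_last stride
# patterns differ, so the dim order is unique and is returned normally.
t = torch.randn(2, 3, 4, 5)
print(t.dim_order(ambiguity_check=[torch.contiguous_format, torch.channels_last]))
# -> (0, 1, 2, 3)

# Ambiguous: with a size-1 channel dimension the same strides satisfy both
# formats, so the check raises. This is the case that
# assert_unambiguous_dim_order converts into AmbiguousDimOrderError.
amb = torch.randn(2, 1, 4, 4)
try:
    amb.dim_order(ambiguity_check=[torch.contiguous_format, torch.channels_last])
except Exception as e:  # the verifier code above also catches broadly
    print(f"ambiguous: {e}")

# Hypothetical wiring (the PR runs the check inside the edge dialect verifier
# when use_dim_order is set; `model` and `example_inputs` are placeholders):
# ep = torch.export.export(model, example_inputs)
# assert_unambiguous_dim_order(ep.graph_module)
```

As the new in-pass comment says, the practical remedy when `AmbiguousDimOrderError` is raised is to re-export with an example input whose strides pin down a single memory format.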