exir/verification/verifier.py: 62 additions & 1 deletion
@@ -15,6 +15,8 @@
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.error import ExportError, ExportErrorType
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_manager import PassManager
from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
from executorch.exir.passes.replace_aten_with_edge_pass import DISALLOW_LIST
@@ -35,6 +37,63 @@

ALLOWED_META_KEYS = {"spec", "stack_trace"}

class AmbiguousDimOrderError(RuntimeError):
    """
    Raised when any node's output tensor dim_order is ambiguous for a given
    list of memory formats.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)
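
# Background: a dim_order is "ambiguous" when a tensor's sizes and strides are
# consistent with more than one memory format, typically because some dimension
# has size 1. Illustration (not part of this change):
#
#   >>> t = torch.randn(1, 1, 4, 4)  # N == C == 1
#   >>> t.is_contiguous(memory_format=torch.contiguous_format)
#   True
#   >>> t.is_contiguous(memory_format=torch.channels_last)
#   True
#
# No unique dim_order can be inferred from the strides of such a tensor.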

def assert_unambiguous_dim_order(gm: GraphModule) -> None:
    class ExampleNOPPass(ExportPass):
        """
        Does nothing! Exists only to give PassManager a pass to run the
        ambiguity check against.
        """

        def call_operator(self, op, args, kwargs, meta):
            return super().call_operator(op, args, kwargs, meta)

    # This is an example of how to detect ambiguous dim_order anywhere in the
    # graph. You can be more surgical and only check the nodes you care about.
    def detect_ambiguity(gm: GraphModule) -> None:
        """
        Check every node's output tensor dim_order and raise if it is
        ambiguous for the given list of memory formats.
        """

        def get_tensors(node: torch.fx.Node) -> List[torch.Tensor]:
            # node.meta["val"] holds the node's example value(s): a tensor for
            # single-output ops, or a list/tuple for multi-output ops.
            val = node.meta["val"]
            if isinstance(val, torch.Tensor):
                return [val]
            elif isinstance(val, (list, tuple)):
                return [tensor for tensor in val if isinstance(tensor, torch.Tensor)]
            return []

        for node in gm.graph.nodes:
            if node.op == "call_function":
                for tensor in get_tensors(node):
                    # Make sure dim_order is not ambiguous; raise otherwise.
                    # We raise because nothing can be done about it here: the
                    # right follow-up is to ask the user to retry with a
                    # different example input.
                    try:
                        _ = tensor.dim_order(
                            ambiguity_check=[
                                torch.contiguous_format,
                                torch.channels_last,
                            ]
                        )
                    except Exception as e:
                        raise AmbiguousDimOrderError(
                            "Tensors should not have ambiguous dim_order; "
                            "try with a different example input."
                        ) from e

    # Any pass or passes would work here; ExampleNOPPass is just a placeholder
    # so that the check registered below gets run.
    dim_order_pass_manager = PassManager(passes=[ExampleNOPPass()])
    dim_order_pass_manager.add_checks(detect_ambiguity)
    dim_order_pass_manager(gm)
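
# A minimal usage sketch (hypothetical module and input, assuming the standard
# torch.export flow, where node.meta["val"] is populated):
#
#   >>> from torch.export import export
#   >>> ep = export(MyModule(), (torch.randn(1, 1, 4, 4),))
#   >>> assert_unambiguous_dim_order(ep.graph_module)
#
# With N == C == 1, contiguous and channels_last strides describe the same
# layout, so the check above may raise AmbiguousDimOrderError.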

def _check_tensors_are_contiguous(gm: GraphModule) -> None:
    # Tensors should be in contiguous format
@@ -281,7 +340,9 @@ def check_additional(self, gm: GraphModule) -> None:
        if self.check_edge_ops:
            _check_tensors_are_contiguous(gm)
            _check_tensor_args_matching_op_allowed_dtype(gm)

        if self.use_dim_order:
            assert_unambiguous_dim_order(gm)
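        # Trigger sketch (hypothetical, assuming the standard to_edge flow):
        # use_dim_order mirrors the edge compile config, so something like
        # to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
        # keeps dim_order ops enabled and routes verification through
        # assert_unambiguous_dim_order above.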

    def is_valid(self, gm: GraphModule) -> bool:
        try:
            self(gm)