Skip to content

Commit 1df8972

Browse files
committed
Added tensor's dim order ambiguity check
1 parent a664d7b commit 1df8972

File tree

1 file changed

+58
-1
lines changed

1 file changed

+58
-1
lines changed

exir/verification/verifier.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1616
from executorch.exir.error import ExportError, ExportErrorType
1717
from executorch.exir.lowered_backend_module import LoweredBackendModule
18+
from executorch.exir.pass_base import ExportPass
19+
from executorch.exir.pass_manager import PassManager
1820
from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap
1921
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
2022
from executorch.exir.passes.replace_aten_with_edge_pass import DISALLOW_LIST
@@ -35,6 +37,59 @@
3537

3638
ALLOWED_META_KEYS = {"spec", "stack_trace"}
3739

40+
class AmbiguousDimOrderError(RuntimeError):
    """Raised when a tensor's dim_order cannot be determined unambiguously.

    A default message is provided so the exception can be raised bare
    (``raise AmbiguousDimOrderError``) without producing a ``TypeError``
    from the missing positional argument; callers may still pass a more
    specific message.
    """

    def __init__(self, message: str = "Tensor dim order is ambiguous") -> None:
        super().__init__(message)
44+
def assert_unambiguous_dim_order(gm) -> None:
    """Verify that every call_function node's output tensors have an
    unambiguous dim_order with respect to contiguous and channels_last.

    Runs a no-op pass over ``gm`` with an attached check so the graph is
    walked through the standard pass infrastructure.

    Args:
        gm: the graph module to verify (an ``ExportedProgram`` graph module).

    Raises:
        AmbiguousDimOrderError: if any node produces a tensor whose
            dim_order is ambiguous for the checked memory formats.
    """

    class ExampleNOPPass(ExportPass):
        """Does nothing! Exists only so the check below can be attached
        to a pass manager run."""

        def call_operator(self, op, args, kwargs, meta):
            return super().call_operator(op, args, kwargs, meta)

    # This is an example of how one can detect ambiguous dim_order anywhere in the graph.
    # You can be surgical and only detect it in the nodes you are interested in or something else.
    def detect_ambiguity(gm) -> None:
        """
        Check every node's output tensor dim_order and raise if it is ambiguous for a list of formats.
        """

        def get_tensors(node: torch.fx.Node) -> List[torch.Tensor]:
            # node.meta["val"] holds the faketensor(s) produced by the node;
            # it may be a single tensor or a list/tuple of values.
            val = node.meta["val"]
            if isinstance(val, torch.Tensor):
                return [val]
            if isinstance(val, (list, tuple)):
                return [tensor for tensor in val if isinstance(tensor, torch.Tensor)]
            return []

        for node in gm.graph.nodes:
            if node.op != "call_function":
                continue
            for tensor in get_tensors(node):
                # Let's make sure dim_order is not ambiguous, raise otherwise.
                # This is raising because we can't do anything about it.
                # The right course of follow up action is to ask user to try
                # with a different example input.
                try:
                    _ = tensor.dim_order(
                        ambiguity_check=[
                            torch.contiguous_format,
                            torch.channels_last,
                        ]
                    )
                except Exception as e:
                    # Raise an *instance* with context; the bare class would
                    # fail to construct if the exception requires a message,
                    # and chaining preserves the underlying cause.
                    raise AmbiguousDimOrderError(
                        f"Ambiguous dim_order detected for node {node.name}; "
                        "try a different example input."
                    ) from e

    # any pass or passes, just using MemoryFormatOpsPass as an example
    dim_order_pass_manager = PassManager(passes=[ExampleNOPPass()])
    dim_order_pass_manager.add_checks(detect_ambiguity)
    dim_order_pass_manager(gm)
3893

3994
def _check_tensors_are_contiguous(gm: GraphModule) -> None:
4095
# Tensors be of contiguous format
@@ -281,7 +336,9 @@ def check_additional(self, gm: GraphModule) -> None:
281336
if self.check_edge_ops:
282337
_check_tensors_are_contiguous(gm)
283338
_check_tensor_args_matching_op_allowed_dtype(gm)
284-
339+
if self.use_dim_order:
340+
assert_unambiguous_dim_order(gm)
341+
285342
def is_valid(self, gm: GraphModule) -> bool:
286343
try:
287344
self(gm)

0 commit comments

Comments
 (0)