Skip to content

Commit 77bcfb0

Browse files
committed
[Not for Commit] Example use of new dim_order api
Just for API usage illustration purposes
1 parent d01810e commit 77bcfb0

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

exir/tests/test_memory_format_ops_pass.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def test_op_dim_order_update(self) -> None:
109109
)
110110

111111
def test_op_dim_order_propagation(self) -> None:
112+
print("test_op_dim_order_propagation: unambiguous path")
112113
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
113114
self,
114115
MemoryFormatTestSet(
@@ -126,6 +127,24 @@ def test_op_dim_order_propagation(self) -> None:
126127
),
127128
)
128129

130+
print("test_op_dim_order_propagation: ambiguous path")
131+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
132+
self,
133+
MemoryFormatTestSet(
134+
module=PropagateToCopyChannalsLastModule().eval(),
135+
op=torch.ops.aten._to_copy.default,
136+
sample_input=(
137+
torch.rand_like(
138+
torch.zeros([2, 1, 2, 2]),
139+
dtype=torch.float32,
140+
memory_format=torch.contiguous_format,
141+
),
142+
),
143+
target_memory_format=torch.channels_last,
144+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
145+
),
146+
)
147+
129148
# Only test dim order replacement result in lean mode test.
130149
# This test is irrelevant with operator mode.
131150
def test_dim_order_replacement(self) -> None:

exir/tests/test_memory_format_ops_pass_utils.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,13 @@
2020
is_channel_last_dim_order,
2121
is_contiguous_dim_order,
2222
)
23+
from executorch.exir.pass_base import ExportPass
24+
25+
from exir.passes.memory_format_ops_pass import MemoryFormatOpsPass
2326

2427
from torch.export import export
28+
29+
from torch.fx.passes.infra.pass_manager import PassManager
2530
from torch.testing import FileCheck
2631
from torch.utils._pytree import tree_flatten
2732

@@ -99,6 +104,50 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
99104
return t1 * t2
100105

101106

107+
def assert_unambiguous_dim_order(gm):
108+
# This is just an example, you can add your own pass or passes.
109+
class ExampleNOPPass(ExportPass):
110+
"""
111+
Does nothing!
112+
"""
113+
114+
def call_operator(self, op, args, kwargs, meta):
115+
return super().call_operator(
116+
op,
117+
args,
118+
kwargs,
119+
meta,
120+
)
121+
122+
# This is an example of how one can detect ambiguous dim_order anywhere in the graph.
123+
# You can be surgical and only detect it in the nodes you are interested in or something else.
124+
def detect_ambiguity(gm):
125+
"""
126+
Check every node's output tensor dim_order and raise if it is ambiguous for a list of formats.
127+
"""
128+
for node in gm.graph.nodes:
129+
if node.op == "call_function":
130+
tensor = node.meta["val"]
131+
# Let's make sure dim_order is not ambiguous, raise otherwise.
132+
# This is raising because we can't do anything about it.
133+
# The right course of follow up action is to ask user to try with a different example input.
134+
print(f"node: {node}, shape: {tensor.shape}, ", end="")
135+
136+
try:
137+
dim_order = tensor.dim_order(
138+
ambiguity_check=[torch.contiguous_format, torch.channels_last]
139+
)
140+
print(f"dim_order: {dim_order}")
141+
except Exception as e:
142+
print("")
143+
raise RuntimeError(e)
144+
145+
# any pass or passes, just using MemoryFormatOpsPass as an example
146+
dim_order_pass_manager = PassManager(passes=[ExampleNOPPass()])
147+
dim_order_pass_manager.add_checks(detect_ambiguity)
148+
dim_order_pass_manager(gm)
149+
150+
102151
class MemoryFormatOpsPassTestUtils:
103152
@staticmethod
104153
def memory_format_test_runner(
@@ -121,6 +170,9 @@ def memory_format_test_runner(
121170
before, compile_config=EdgeCompileConfig(_skip_dim_order=False)
122171
)
123172

173+
# Just as an example
174+
assert_unambiguous_dim_order(epm.exported_program().graph_module)
175+
124176
# check memory format ops, if needed
125177
if test_set.op_level_check:
126178
aten_op_str, edge_op_str = MemoryFormatOps2Str[test_set.op]
@@ -153,6 +205,8 @@ def memory_format_test_runner(
153205

154206
# check EdgeOp and the new BackendOp should behave the same in the runtime
155207
executorch_prog = epm.to_executorch()
208+
if test_set._load_for_executorch_from_buffer is None:
209+
return
156210

157211
executorch_module = test_set._load_for_executorch_from_buffer(
158212
executorch_prog.buffer

0 commit comments

Comments
 (0)