Skip to content

Commit b886504

Browse files
[et][dim order] Makes DimOrderOpsMap as an operator to operator mapping
Pull Request resolved: #7187 This diff updates DimOrderOpsMap from name-to-operator mapping to operator-to-operator mapping, which has multiple benefits: 1. Reduce dialect ambiguity. Different dialects op may map to same name (e.g. `aten.to_copy` and `exir_ops.edge.aten._to_copy.default`). Directly using op can diminish the ambiguity. 2. Auto-maintain MemoryFormatOpsMap by reverting DimOrderOpsMap ghstack-source-id: 256633613 @exported-using-ghexport Differential Revision: [D66773612](https://our.internmc.facebook.com/intern/diff/D66773612/) Co-authored-by: gasoonjia <[email protected]>
1 parent 1ac6ee5 commit b886504

File tree

3 files changed

+12
-19
lines changed

3 files changed

+12
-19
lines changed

exir/passes/dim_order_ops_registry.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,22 +58,14 @@ def _empty_dim_order_out_impl(*args, **kwargs):
5858

5959

6060
"""
61-
Defines a map of aten or edge ops to the corresponding dim_order ops for quick lookup
61+
Defines a map of edge ops to the corresponding dim_order ops for quick lookup
6262
"""
6363
DimOrderOpsMap = {
64-
"aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
65-
"aten.empty.memory_format": exir_ops.edge.dim_order_ops._empty_dim_order.default,
64+
exir_ops.edge.aten._to_copy.default: exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
65+
exir_ops.edge.aten.empty.memory_format: exir_ops.edge.dim_order_ops._empty_dim_order.default,
6666
}
6767

6868
"""
69-
Defines a map of aten or edge ops to the corresponding memory format ops for quick lookup
69+
Defines a map of edge ops to the corresponding memory format ops for quick lookup, which is the revert of DimOrderOpsMap
7070
"""
71-
MemoryFormatOpsMap = {
72-
"dim_order_ops._to_dim_order_copy.default": exir_ops.edge.aten._to_copy.default,
73-
"dim_order_ops._empty_dim_order.default": exir_ops.edge.aten.empty.memory_format,
74-
}
75-
76-
# If we are replacing an aten op with a dim_order op, we must have a 1:1 mapping through these dicts.
77-
assert len(DimOrderOpsMap) == len(MemoryFormatOpsMap)
78-
79-
# TODO stricter check for 1:1 mapping
71+
MemoryFormatOpsMap = {v: k for k, v in DimOrderOpsMap.items()}

exir/passes/memory_format_ops_pass.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class MemoryFormatOpsPass(ExportPass):
3232
"""
3333

3434
def call_operator(self, op, args, kwargs, meta):
35-
if not (isinstance(op, EdgeOpOverload) and op.__name__ in DimOrderOpsMap):
35+
if not (isinstance(op, EdgeOpOverload) and op in DimOrderOpsMap):
3636
return super().call_operator(
3737
op,
3838
args,
@@ -61,10 +61,10 @@ def call_operator(self, op, args, kwargs, meta):
6161
nkwargs["dim_order"] = get_dim_order(mem_format, ndim)
6262
logger.debug(
6363
f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}."
64-
f" {DimOrderOpsMap[op.__name__].__name__} = dim_order: {nkwargs['dim_order']}"
64+
f" {DimOrderOpsMap[op].__name__} = dim_order: {nkwargs['dim_order']}"
6565
)
6666

67-
t = DimOrderOpsMap[op.__name__]
67+
t = DimOrderOpsMap[op]
6868

6969
return super().call_operator(
7070
t,
@@ -80,7 +80,7 @@ class DimOrderOpsRevertPass(ExportPass):
8080
"""
8181

8282
def call_operator(self, op, args, kwargs, meta):
83-
if not (isinstance(op, EdgeOpOverload) and op.__name__ in MemoryFormatOpsMap):
83+
if not (isinstance(op, EdgeOpOverload) and op in MemoryFormatOpsMap):
8484
return super().call_operator(
8585
op,
8686
args,
@@ -109,10 +109,10 @@ def call_operator(self, op, args, kwargs, meta):
109109

110110
logger.debug(
111111
f" {op.__name__} = dim_order: {dim_order}."
112-
f" {MemoryFormatOpsMap[op.__name__].__name__} = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
112+
f" {MemoryFormatOpsMap[op].__name__} = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
113113
)
114114

115-
t = MemoryFormatOpsMap[op.__name__]
115+
t = MemoryFormatOpsMap[op]
116116

117117
return super().call_operator(
118118
t,

exir/verification/verifier.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
179179
if validator.violating_ops:
180180
raise SpecViolationError(
181181
f"These operators are taking Tensor inputs with mismatched dtypes: {validator.violating_ops}"
182+
"Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding "
182183
)
183184

184185

0 commit comments

Comments
 (0)