exir/passes/dim_order_ops_registry.py (18 changes: 5 additions & 13 deletions)
@@ -58,22 +58,14 @@ def _empty_dim_order_out_impl(*args, **kwargs):
 
 
 """
-Defines a map of aten or edge ops to the corresponding dim_order ops for quick lookup
+Defines a map of edge ops to the corresponding dim_order ops for quick lookup
 """
 DimOrderOpsMap = {
-    "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
-    "aten.empty.memory_format": exir_ops.edge.dim_order_ops._empty_dim_order.default,
+    exir_ops.edge.aten._to_copy.default: exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+    exir_ops.edge.aten.empty.memory_format: exir_ops.edge.dim_order_ops._empty_dim_order.default,
 }
 
 """
-Defines a map of aten or edge ops to the corresponding memory format ops for quick lookup
+Defines a map of edge ops to the corresponding memory format ops for quick lookup, which is the inverse of DimOrderOpsMap
 """
-MemoryFormatOpsMap = {
-    "dim_order_ops._to_dim_order_copy.default": exir_ops.edge.aten._to_copy.default,
-    "dim_order_ops._empty_dim_order.default": exir_ops.edge.aten.empty.memory_format,
-}
-
-# If we are replacing an aten op with a dim_order op, we must have a 1:1 mapping through these dicts.
-assert len(DimOrderOpsMap) == len(MemoryFormatOpsMap)
-
+# TODO: stricter check for 1:1 mapping
+MemoryFormatOpsMap = {v: k for k, v in DimOrderOpsMap.items()}
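
Note on the registry change above: inverting a Python dict with a comprehension silently drops entries whenever two keys map to the same value, so {v: k for k, v in ...} is only a faithful inverse when the forward map is 1:1, hence the remaining TODO. A minimal standalone sketch of the stricter check, using string stand-ins for the EdgeOpOverload keys and values:

# Stand-in mapping; in the real registry the keys and values are EdgeOpOverload objects.
forward_map = {
    "aten._to_copy.default": "dim_order_ops._to_dim_order_copy.default",
    "aten.empty.memory_format": "dim_order_ops._empty_dim_order.default",
}

reverse_map = {v: k for k, v in forward_map.items()}

# If two keys shared a value, the inverted dict would be smaller than the
# original and the round trip between the two passes would silently break.
assert len(reverse_map) == len(forward_map), "mapping must be 1:1 to be invertible"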
exir/passes/memory_format_ops_pass.py (12 changes: 6 additions & 6 deletions)
@@ -32,7 +32,7 @@ class MemoryFormatOpsPass(ExportPass):
     """
 
     def call_operator(self, op, args, kwargs, meta):
-        if not (isinstance(op, EdgeOpOverload) and op.__name__ in DimOrderOpsMap):
+        if not (isinstance(op, EdgeOpOverload) and op in DimOrderOpsMap):
             return super().call_operator(
                 op,
                 args,
@@ -61,10 +61,10 @@ def call_operator(self, op, args, kwargs, meta):
nkwargs["dim_order"] = get_dim_order(mem_format, ndim)
logger.debug(
f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}."
f" {DimOrderOpsMap[op.__name__].__name__} = dim_order: {nkwargs['dim_order']}"
f" {DimOrderOpsMap[op].__name__} = dim_order: {nkwargs['dim_order']}"
)

t = DimOrderOpsMap[op.__name__]
t = DimOrderOpsMap[op]

return super().call_operator(
t,
Expand All @@ -80,7 +80,7 @@ class DimOrderOpsRevertPass(ExportPass):
"""

def call_operator(self, op, args, kwargs, meta):
if not (isinstance(op, EdgeOpOverload) and op.__name__ in MemoryFormatOpsMap):
if not (isinstance(op, EdgeOpOverload) and op in MemoryFormatOpsMap):
return super().call_operator(
op,
args,
Expand Down Expand Up @@ -109,10 +109,10 @@ def call_operator(self, op, args, kwargs, meta):

         logger.debug(
             f" {op.__name__} = dim_order: {dim_order}."
-            f" {MemoryFormatOpsMap[op.__name__].__name__} = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
+            f" {MemoryFormatOpsMap[op].__name__} = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
         )
 
-        t = MemoryFormatOpsMap[op.__name__]
+        t = MemoryFormatOpsMap[op]
 
         return super().call_operator(
             t,
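
The substantive change in both passes is the lookup key: two distinct overloads can share a __name__, while EdgeOpOverload objects are hashable and unique, so indexing the table by the op itself avoids string collisions. For context, a simplified stand-in (not the actual ExecuTorch helper) for the get_dim_order conversion the pass performs when it rewrites the memory_format kwarg into a dim_order list:

import torch

def get_dim_order_sketch(memory_format, ndim):
    # Contiguous tensors keep the natural dimension order 0..ndim-1;
    # channels_last on a rank-4 tensor orders strides as NHWC.
    if memory_format in (torch.contiguous_format, None):
        return list(range(ndim))
    if memory_format == torch.channels_last and ndim == 4:
        return [0, 2, 3, 1]
    raise AssertionError(f"unsupported memory format: {memory_format}")

print(get_dim_order_sketch(torch.channels_last, 4))  # [0, 2, 3, 1]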
exir/verification/test/test_verifier.py (5 changes: 3 additions & 2 deletions)
@@ -117,8 +117,9 @@ def __init__(self):
 
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 t1 = x.to(dtype=torch.double, memory_format=torch.channels_last)
-                t2 = t1 + t1
-                return t1 * t2
+                t2 = torch.empty(t1.size(), memory_format=torch.channels_last)
+                t2.copy_(t1)
+                return t2
 
         m = Model().eval()
 
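
The rewritten test forward replaces pointwise arithmetic with an explicit empty-plus-copy_ pattern, so the exported graph exercises aten.empty.memory_format, the second op in the registry, in addition to aten._to_copy.default. A standalone repro of the pattern (input shape assumed for illustration):

import torch

x = torch.randn(2, 3, 4, 5)
t1 = x.to(dtype=torch.double, memory_format=torch.channels_last)
t2 = torch.empty(t1.size(), memory_format=torch.channels_last)
t2.copy_(t1)  # copy_ casts t1's double values to t2's default float32 dtype
assert t2.is_contiguous(memory_format=torch.channels_last)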
exir/verification/verifier.py (7 changes: 5 additions & 2 deletions)
@@ -15,10 +15,12 @@
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.error import ExportError, ExportErrorType
 from executorch.exir.lowered_backend_module import LoweredBackendModule
+from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap
 from executorch.exir.verification.arg_validator import (
     EdgeOpArgValidator,
     RunHigherOrderOperatorError,
 )
+
 from torch._dispatch.python import enable_python_dispatcher
 from torch._export.utils import _detect_fake_mode_from_gm
 
@@ -44,7 +46,7 @@ def _check_tensors_are_contiguous(gm: GraphModule) -> None:

 def _check_valid_dim_order_ops(op, use_dim_order) -> None:
     if use_dim_order:
-        if op in (torch.ops.aten._to_copy.default,):
+        if op in DimOrderOpsMap:
             raise SpecViolationError(f"{op} should not be used in dim_order mode")
     else:  # not using dim_order
         if op.namespace in ("dim_order_ops",):
@@ -179,6 +181,7 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
     if validator.violating_ops:
         raise SpecViolationError(
             f"These operators are taking Tensor inputs with mismatched dtypes: {validator.violating_ops}"
+            "Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding "
         )


@@ -248,7 +251,7 @@ def check_valid_edge_op(self, op):
                 )
             )
         if isinstance(op, EdgeOpOverload):
-            _check_valid_dim_order_ops(op._op, self.use_dim_order)
+            _check_valid_dim_order_ops(op, self.use_dim_order)
             self.check_valid_aten_op(op._op)
 
         if isinstance(op, types.FunctionType):
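
With DimOrderOpsMap imported, the verifier rejects any aten op that has a dim_order replacement rather than the single hard-coded _to_copy tuple, and it now receives the EdgeOpOverload itself so membership is an exact object lookup. A minimal sketch (hypothetical names, simplified from the real verifier) of the two-sided check:

# Hypothetical, simplified version of _check_valid_dim_order_ops.
def check_valid_dim_order_ops_sketch(op, use_dim_order, dim_order_ops_map):
    if use_dim_order:
        # Ops with a dim_order replacement should already have been rewritten.
        if op in dim_order_ops_map:
            raise ValueError(f"{op} should not be used in dim_order mode")
    else:
        # Conversely, dim_order ops must not leak into memory_format programs.
        if getattr(op, "namespace", None) == "dim_order_ops":
            raise ValueError(f"{op} should not be used without dim_order mode")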