diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index bf4a274134d..c7cea31b492 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -33,7 +33,6 @@ ExecutorchProgramManager, to_edge, ) -from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import PassResult from executorch.exir.passes import ToOutVarPass from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass @@ -187,6 +186,7 @@ def export_to_edge( edge_prog_manager = to_edge( expo_program, compile_config=EdgeCompileConfig( + _skip_dim_order=True, # Allow specific non-core aten ops in the IR. _core_aten_ops_exception_list=[ torch.ops.aten._native_batch_norm_legit_functional.default, @@ -194,10 +194,6 @@ def export_to_edge( torch.ops.aten.linalg_vector_norm.default, torch.ops.aten.unfold.default, torch.ops.aten.angle.default, - # cadence replaced to_dim_order_copy with _to_copy for performance - # skip _to_copy op to get around of dim order check - # We should remove this op once cadence can support dim order - exir_ops.edge.aten._to_copy.default, ], ), constant_methods=constant_methods, diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index cc304a226a6..89ef821c569 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -11,7 +11,6 @@ # pyre-unsafe -import copy import math from operator import neg from typing import cast, Dict, Iterable, Sequence, Set, Tuple @@ -36,12 +35,7 @@ from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket -from executorch.exir.dim_order_utils import get_memory_format from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue -from executorch.exir.passes.dim_order_ops_registry import ( - DimOrderOpsMap, - MemoryFormatOpsMap, -) from torch._subclasses import FakeTensor from torch.fx.node import Argument @@ -1805,72 +1799,6 @@ def call_operator( ) -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceToDimOrderCopyWithToCopyPass(ExportPass): - """ - dim_order_ops::to_dim_order_copy is not supported, so this is an opt_level=0 pass. - If the dim order is sequential, we don't need the extra work with strides and - can just use to_copy. - """ - - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in DimOrderOpsMap: - return super().call_operator(op, args, kwargs, meta) - - # new kwargs with dim_order, and no memory_format for the new op - nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable - - ndim = None - - # can always get the shape, assuming rank is specialized - - # pyre-ignore[16]: `None` has no attribute `to_tensor` - if isinstance(args[0], ProxyValue) and args[0].is_tensor(): - # pyre-ignore[16]: `None` has no attribute `to_tensor` - ndim = args[0].to_tensor().dim() - elif isinstance(args[0], torch.Tensor): - # pyre-ignore[16]: `None` has no attribute `dim` - ndim = args[0].dim() - elif isinstance(args[0], torch.fx.immutable_collections.immutable_list): - # pyre-ignore[6]: Incompatible parameter type - ndim = len(args[0]) - else: - assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}" - - # get the "to" memory format for the EdgeOp - contiguous_dim_order = list(range(ndim)) - dim_order = nkwargs.pop("dim_order", None) - - # Cadence only supports contiguous memory format - assert ( - dim_order is None - # pyre-ignore[6]: Incompatible parameter type - or len(dim_order) == 0 - or dim_order == contiguous_dim_order - ), "Expected dim order in congituous or prevserve memory format, but got {}".format( - dim_order - ) - - # bring back memory format - # pyre-ignore[6]: Incompatible parameter type - nkwargs["memory_format"] = get_memory_format(dim_order) - - memory_format_op = MemoryFormatOpsMap[op] - - return super().call_operator( - memory_format_op, - args, - nkwargs, - meta, - ) - - @register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceFullLikeWithFullPass(ExportPass): """ @@ -2180,5 +2108,4 @@ class CadenceReplaceOpsInGraph: ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass, ReplaceAtenAvgPoolWithJarvisAvgPoolPass, ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass, - ReplaceToDimOrderCopyWithToCopyPass, ]