diff --git a/exir/verification/test/test_verifier.py b/exir/verification/test/test_verifier.py index b2e31dbc599..1ee48ef4d43 100644 --- a/exir/verification/test/test_verifier.py +++ b/exir/verification/test/test_verifier.py @@ -117,8 +117,9 @@ def __init__(self): def forward(self, x: torch.Tensor) -> torch.Tensor: t1 = x.to(dtype=torch.double, memory_format=torch.channels_last) - t2 = t1 + t1 - return t1 * t2 + t2 = torch.empty(t1.size(), memory_format=torch.channels_last) + t2.copy_(t1) + return t2 m = Model().eval() diff --git a/exir/verification/verifier.py b/exir/verification/verifier.py index f906623ca25..4a37b457e30 100644 --- a/exir/verification/verifier.py +++ b/exir/verification/verifier.py @@ -15,10 +15,12 @@ from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.error import ExportError, ExportErrorType from executorch.exir.lowered_backend_module import LoweredBackendModule +from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap from executorch.exir.verification.arg_validator import ( EdgeOpArgValidator, RunHigherOrderOperatorError, ) + from torch._dispatch.python import enable_python_dispatcher from torch._export.utils import _detect_fake_mode_from_gm @@ -44,7 +46,7 @@ def _check_tensors_are_contiguous(gm: GraphModule) -> None: def _check_valid_dim_order_ops(op, use_dim_order) -> None: if use_dim_order: - if op in (torch.ops.aten._to_copy.default,): + if op in DimOrderOpsMap: raise SpecViolationError(f"{op} should not be used in dim_order mode") else: # not using dim_order if op.namespace in ("dim_order_ops",): @@ -249,7 +251,7 @@ def check_valid_edge_op(self, op): ) ) if isinstance(op, EdgeOpOverload): - _check_valid_dim_order_ops(op._op, self.use_dim_order) + _check_valid_dim_order_ops(op, self.use_dim_order) self.check_valid_aten_op(op._op) if isinstance(op, types.FunctionType):