diff --git a/exir/passes/dim_order_ops_registry.py b/exir/passes/dim_order_ops_registry.py index 27fc03f9413..c4436aaa910 100644 --- a/exir/passes/dim_order_ops_registry.py +++ b/exir/passes/dim_order_ops_registry.py @@ -15,11 +15,19 @@ "_to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int[]? dim_order=None) -> Tensor" ) -# Out variant drops TensorOptions +lib.define( + "_empty_dim_order(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, int[]? dim_order=None) -> Tensor" +) + +# Out variant of aten::_to_copy and aten::empty drops TensorOptions, so do their dim order variants lib.define( "_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)" +) + def _op_impl(target, *args, **kwargs): kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None)) @@ -39,11 +47,22 @@ def _to_dim_order_copy_out_impl(*args, **kwargs): return _op_impl(torch.ops.aten._to_copy.out, *args, **kwargs) +@impl(lib, "_empty_dim_order", "CompositeImplicitAutograd") +def _empty_dim_order_impl(*args, **kwargs): + return _op_impl(torch.ops.aten.empty.memory_format, *args, **kwargs) + + +@impl(lib, "_empty_dim_order.out", "CompositeImplicitAutograd") +def _empty_dim_order_out_impl(*args, **kwargs): + return _op_impl(torch.ops.aten.empty.out, *args, **kwargs) + + """ Defines a map of aten or edge ops to the corresponding dim_order ops for quick lookup """ DimOrderOpsMap = { "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + "aten.empty.memory_format": exir_ops.edge.dim_order_ops._empty_dim_order.default, } """ @@ -51,6 +70,7 @@ def _to_dim_order_copy_out_impl(*args, **kwargs): """ MemoryFormatOpsMap = { "dim_order_ops._to_dim_order_copy.default": exir_ops.edge.aten._to_copy.default, + "dim_order_ops._empty_dim_order.default": exir_ops.edge.aten.empty.memory_format, } # If we are replacing an aten op with a dim_order op, we must have a 1:1 mapping through these dicts. diff --git a/exir/passes/memory_format_ops_pass.py b/exir/passes/memory_format_ops_pass.py index 32678bf4082..ba89a510a71 100644 --- a/exir/passes/memory_format_ops_pass.py +++ b/exir/passes/memory_format_ops_pass.py @@ -39,6 +39,7 @@ def call_operator(self, op, args, kwargs, meta): kwargs, meta, ) + # new kwargs with dim_order, and no memory_format for the new op nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable @@ -50,17 +51,20 @@ def call_operator(self, op, args, kwargs, meta): ndim = args[0].to_tensor().dim() elif isinstance(args[0], torch.Tensor): ndim = args[0].dim() + elif isinstance(args[0], torch.fx.immutable_collections.immutable_list): + ndim = len(args[0]) else: - assert 0, f"Expecting a Tensor or a ProxyValue buy got {type(args[0])}" + assert ( + 0 + ), f"Expecting a Tensor, a ProxyValue, or a Sequence, but got {type(args[0])}" nkwargs["dim_order"] = get_dim_order(mem_format, ndim) logger.debug( - f"_to_copy = rank: {ndim}, memory_format: {mem_format}." - f" _to_dim_order_copy = dim_order: {nkwargs['dim_order']}" + f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}." 
+ f" {DimOrderOpsMap[op.__name__].__name__} = dim_order: {nkwargs['dim_order']}" ) - t = DimOrderOpsMap.get(op.__name__, None) - assert t is not None, f"{op.__name__} not found in DimOrderOpsMap" + t = DimOrderOpsMap[op.__name__] return super().call_operator( t, @@ -92,8 +96,10 @@ def call_operator(self, op, args, kwargs, meta): ndim = args[0].to_tensor().dim() elif isinstance(args[0], torch.Tensor): ndim = args[0].dim() + elif isinstance(args[0], torch.fx.immutable_collections.immutable_list): + ndim = len(args[0]) else: - assert 0, f"Expecting a Tensor or a ProxyValue buy got {type(args[0])}" + assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}" # get the "to" memory format for the EdgeOp default_dim_order = list(range(ndim)) @@ -102,12 +108,11 @@ def call_operator(self, op, args, kwargs, meta): nkwargs["memory_format"] = get_memory_format(dim_order) logger.debug( - f" _to_dim_order_copy = dim_order: {dim_order}." - f"_to_copy = rank: {ndim}, memory_format: {nkwargs['memory_format']}." + f" {op.__name__} = dim_order: {dim_order}." + f" {MemoryFormatOpsMap[op.__name__].__name__} = rank: {ndim}, memory_format: {nkwargs['memory_format']}." ) - t = MemoryFormatOpsMap.get(op.__name__, None) - assert t is not None, f"{op.__name__} not found in MemoryFormatOpsMap" + t = MemoryFormatOpsMap[op.__name__] return super().call_operator( t, diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index 53befded94b..0292cf98f50 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -27,6 +27,8 @@ MemoryFormatOpsPassTestUtils, MemoryFormatTestSet, PropagateToCopyChannalsLastModule, + SimpleEmptyChannelLastModule, + SimpleEmptyContiguoustModule, SimpleToCopyChannelsLastModule, SimpleToCopyContiguousModule, ) @@ -45,6 +47,7 @@ def test_op_to_copy_replacement_2d(self) -> None: self, MemoryFormatTestSet( module=SimpleToCopyContiguousModule().eval(), + op=torch.ops.aten._to_copy.default, sample_input=(torch.randn([3, 4, 5], dtype=torch.float32),), target_memory_format=torch.contiguous_format, _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, @@ -56,17 +59,43 @@ def test_op_to_copy_replacement_4d(self) -> None: self, MemoryFormatTestSet( module=SimpleToCopyContiguousModule().eval(), + op=torch.ops.aten._to_copy.default, sample_input=(torch.randn([3, 4, 5, 6], dtype=torch.float32),), target_memory_format=torch.contiguous_format, _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, ), ) + def test_op_empty_replacement_channels_last(self) -> None: + MemoryFormatOpsPassTestUtils.memory_format_test_runner( + self, + MemoryFormatTestSet( + module=SimpleEmptyChannelLastModule().eval(), + op=torch.ops.aten.empty.memory_format, + sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),), + target_memory_format=torch.channels_last, + _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, + ), + ) + + def test_op_empty_replacement_contiguous(self) -> None: + MemoryFormatOpsPassTestUtils.memory_format_test_runner( + self, + MemoryFormatTestSet( + module=SimpleEmptyContiguoustModule().eval(), + op=torch.ops.aten.empty.memory_format, + sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),), + target_memory_format=torch.contiguous_format, + _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, + ), + ) + def test_op_dim_order_update(self) -> None: MemoryFormatOpsPassTestUtils.memory_format_test_runner( self, MemoryFormatTestSet( 
module=SimpleToCopyChannelsLastModule().eval(), + op=torch.ops.aten._to_copy.default, sample_input=( torch.rand_like( torch.zeros([2, 2, 2, 2]), @@ -84,6 +113,7 @@ def test_op_dim_order_propagation(self) -> None: self, MemoryFormatTestSet( module=PropagateToCopyChannalsLastModule().eval(), + op=torch.ops.aten._to_copy.default, sample_input=( torch.rand_like( torch.zeros([2, 2, 2, 2]), @@ -273,6 +303,7 @@ def test_resnet18(self) -> None: self, MemoryFormatTestSet( module=model.eval(), + op=torch.ops.aten._to_copy.default, sample_input=(torch.randn(1, 3, 224, 224),), target_memory_format=torch.contiguous_format, op_level_check=False, @@ -288,6 +319,7 @@ def test_resnet18_xnnpack(self) -> None: self, MemoryFormatTestSet( module=model.eval(), + op=torch.ops.aten._to_copy.default, sample_input=(torch.randn(1, 3, 224, 224),), target_memory_format=torch.contiguous_format, op_level_check=False, @@ -304,6 +336,7 @@ def test_mobilenet_v3(self) -> None: self, MemoryFormatTestSet( module=model.eval(), + op=torch.ops.aten._to_copy.default, sample_input=(torch.randn(1, 3, 224, 224),), target_memory_format=torch.contiguous_format, op_level_check=False, @@ -319,6 +352,7 @@ def test_mobilenet_v3_xnnpack(self) -> None: self, MemoryFormatTestSet( module=model.eval(), + op=torch.ops.aten._to_copy.default, sample_input=(torch.randn(1, 3, 224, 224),), target_memory_format=torch.contiguous_format, op_level_check=False, diff --git a/exir/tests/test_memory_format_ops_pass_aten.py b/exir/tests/test_memory_format_ops_pass_aten.py index 601893fd238..5aa687e6aef 100644 --- a/exir/tests/test_memory_format_ops_pass_aten.py +++ b/exir/tests/test_memory_format_ops_pass_aten.py @@ -13,6 +13,8 @@ MemoryFormatOpsPassTestUtils, MemoryFormatTestSet, PropagateToCopyChannalsLastModule, + SimpleEmptyChannelLastModule, + SimpleEmptyContiguoustModule, SimpleToCopyChannelsLastModule, SimpleToCopyContiguousModule, ) @@ -28,6 +30,7 @@ def test_op_to_copy_replacement_2d_aten(self) -> None: self, MemoryFormatTestSet( module=SimpleToCopyContiguousModule().eval(), + op=torch.ops.aten._to_copy.default, sample_input=(torch.randn([3, 4, 5], dtype=torch.float32),), target_memory_format=torch.contiguous_format, _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, @@ -39,17 +42,43 @@ def test_op_to_copy_replacement_4d_aten(self) -> None: self, MemoryFormatTestSet( module=SimpleToCopyContiguousModule().eval(), + op=torch.ops.aten._to_copy.default, sample_input=(torch.randn([3, 4, 5, 6], dtype=torch.float32),), target_memory_format=torch.contiguous_format, _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, ), ) + def test_op_empty_replacement_channels_last(self) -> None: + MemoryFormatOpsPassTestUtils.memory_format_test_runner( + self, + MemoryFormatTestSet( + module=SimpleEmptyChannelLastModule().eval(), + op=torch.ops.aten.empty.memory_format, + sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),), + target_memory_format=torch.channels_last, + _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, + ), + ) + + def test_op_empty_replacement_contiguous(self) -> None: + MemoryFormatOpsPassTestUtils.memory_format_test_runner( + self, + MemoryFormatTestSet( + module=SimpleEmptyContiguoustModule().eval(), + op=torch.ops.aten.empty.memory_format, + sample_input=(torch.randn((1, 10, 24, 24), dtype=torch.float32),), + target_memory_format=torch.contiguous_format, + _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, + ), + ) + def test_op_dim_order_update_aten(self) -> None: 
MemoryFormatOpsPassTestUtils.memory_format_test_runner( self, MemoryFormatTestSet( module=SimpleToCopyChannelsLastModule().eval(), + op=torch.ops.aten._to_copy.default, sample_input=( torch.rand_like( torch.zeros([2, 2, 2, 2]), @@ -67,6 +96,7 @@ def test_op_dim_order_propagation_aten(self) -> None: self, MemoryFormatTestSet( module=PropagateToCopyChannalsLastModule().eval(), + op=torch.ops.aten._to_copy.default, sample_input=( torch.rand_like( torch.zeros([2, 2, 2, 2]), @@ -85,6 +115,7 @@ def test_resnet18(self) -> None: self, MemoryFormatTestSet( module=model.eval(), + op=torch.ops.aten._to_copy.default, sample_input=(torch.randn(1, 3, 224, 224),), target_memory_format=torch.contiguous_format, op_level_check=False, @@ -100,6 +131,7 @@ def test_mobilenet_v3(self) -> None: self, MemoryFormatTestSet( module=model.eval(), + op=torch.ops.aten._to_copy.default, sample_input=(torch.randn(1, 3, 224, 224),), target_memory_format=torch.contiguous_format, op_level_check=False, diff --git a/exir/tests/test_memory_format_ops_pass_utils.py b/exir/tests/test_memory_format_ops_pass_utils.py index 8ae5c0190a4..3049f30a8cb 100644 --- a/exir/tests/test_memory_format_ops_pass_utils.py +++ b/exir/tests/test_memory_format_ops_pass_utils.py @@ -8,7 +8,7 @@ import unittest from dataclasses import dataclass -from typing import Any, Tuple +from typing import Any, Dict, List, Tuple import torch @@ -26,11 +26,24 @@ from torch.utils._pytree import tree_flatten +MemoryFormatOps2Str: Dict[torch._ops.OpOverload, List[str]] = { + torch.ops.aten._to_copy.default: ( + "torch.ops.aten._to_copy.default", + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default", + ), + torch.ops.aten.empty.memory_format: ( + "torch.ops.aten.empty.memory_format", + "executorch_exir_dialects_edge__ops_dim_order_ops__empty_dim_order_default", + ), +} + + @dataclass class MemoryFormatTestSet: module: torch.nn.Module sample_input: Tuple[Any, ...] 
target_memory_format: torch.memory_format + op: torch._ops.OpOverload _load_for_executorch_from_buffer: Any op_level_check: bool = True use_xnnpack: bool = False @@ -54,6 +67,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x.to(dtype=torch.double, memory_format=torch.channels_last) +class SimpleEmptyContiguoustModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + empty_tensor = torch.empty(x.size(), memory_format=torch.contiguous_format) + x = x.to(memory_format=torch.contiguous_format) + empty_tensor.copy_(x) + return empty_tensor + + +class SimpleEmptyChannelLastModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + empty_tensor = torch.empty(x.size(), memory_format=torch.channels_last) + x = x.to(memory_format=torch.channels_last) + empty_tensor.copy_(x) + return empty_tensor + + class PropagateToCopyChannalsLastModule(torch.nn.Module): def __init__(self): super().__init__() @@ -86,9 +121,7 @@ def memory_format_test_runner( # check memory format ops, if needed if test_set.op_level_check: - aten_op_str = "torch.ops.aten._to_copy.default" - edge_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" - + aten_op_str, edge_op_str = MemoryFormatOps2Str[test_set.op] # check op strings before FileCheck().check_count(aten_op_str, 1, exactly=True).check_not( edge_op_str @@ -126,6 +159,7 @@ def memory_format_test_runner( runtime_output = executorch_module.run_method( "forward", tuple(inputs_flattened) )[0] + test_class.assertTrue( torch.allclose( runtime_output, expected, atol=test_set.atol, rtol=test_set.rtol diff --git a/kernels/aten/cpu/op__empty_dim_order.cpp b/kernels/aten/cpu/op__empty_dim_order.cpp new file mode 100644 index 00000000000..f11f853daa2 --- /dev/null +++ b/kernels/aten/cpu/op__empty_dim_order.cpp @@ -0,0 +1,110 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +#include +#include + +namespace torch { +namespace executor { +namespace native { + +using exec_aten::IntArrayRef; +using exec_aten::Tensor; +using OptionalIntArrayRef = exec_aten::OptionalArrayRef; +using DimOrderArrayRef = exec_aten::ArrayRef; +// Out Aten tensor shall have same memory format stride as dim_order +const size_t kMaxNumOfDimensions = 16; + +namespace { + +inline bool _check__empty_out_dim_order( + OptionalIntArrayRef dim_order, + Tensor& out) { + exec_aten::ArrayRef dim_order_ref; + std::vector dim_order_vec; + + if (dim_order.has_value()) { + // out tensor's dim order shall equal to input dim order + dim_order_ref = exec_aten::ArrayRef( + dim_order.value().data(), dim_order.value().size()); + } else { // dim_order is not set, out tensor should be contiguous dim order + for (int i = 0; i < out.dim(); i++) { + dim_order_vec.push_back(i); + } + dim_order_ref = exec_aten::ArrayRef(dim_order_vec); + } + + // dim order size shall equal to input dim + ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == out.dim()); + + ET_LOG_AND_RETURN_IF_FALSE( + is_channels_last_dim_order(dim_order_ref.data(), dim_order_ref.size()) || + is_contiguous_dim_order(dim_order_ref.data(), dim_order_ref.size())); + + ET_LOG_AND_RETURN_IF_FALSE(kMaxNumOfDimensions >= out.dim()); + exec_aten::StridesType target_strides[kMaxNumOfDimensions]; + dim_order_to_stride_nocheck( + out.sizes().data(), + dim_order_ref.data(), + dim_order_ref.size(), + target_strides); + + for (size_t i = 0; i < dim_order_ref.size(); i++) { + ET_LOG_AND_RETURN_IF_FALSE(target_strides[i] == out.strides()[i]); + } + + return true; +} + +} // namespace + +/* + * Empty out tensor with specified dim order + * + * _empty_dim_order.out(SymInt[] size, *, int[]? dim_order=None, Tensor(a!) out) + * -> Tensor(a!) + */ +Tensor& _empty_dim_order_out( + KernelRuntimeContext& context, + IntArrayRef size, + OptionalIntArrayRef dim_order, + Tensor& out) { + (void)context; + + // Check if dim_order is valid + ET_KERNEL_CHECK( + context, + _check__empty_out_dim_order(dim_order, out), + InvalidArgument, + out); + + // Resize for dynamic shape + ET_KERNEL_CHECK_MSG( + context, + resize_tensor(out, size) == Error::Ok, + InvalidArgument, + out, + "Failed to resize output tensor."); + + return out; +} + +Tensor& _empty_dim_order_out( + IntArrayRef size, + OptionalIntArrayRef dim_order, + Tensor& out) { + executorch::runtime::KernelRuntimeContext ctx{}; + return _empty_dim_order_out(ctx, size, dim_order, out); +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/aten/cpu/targets.bzl b/kernels/aten/cpu/targets.bzl index bdd93bda9ed..bb7083c1f01 100644 --- a/kernels/aten/cpu/targets.bzl +++ b/kernels/aten/cpu/targets.bzl @@ -9,6 +9,9 @@ load("@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", "d # ops, and must be split. They can, however, share common code via a library dep # if necessary. _EDGE_DIALECT_OPS = ( + op_target( + name = "op__empty_dim_order", + ), op_target( name = "op__to_dim_order_copy", deps = [ diff --git a/kernels/aten/edge_dialect_aten_op.yaml b/kernels/aten/edge_dialect_aten_op.yaml index 016f8dbfab5..d9de3f6dded 100644 --- a/kernels/aten/edge_dialect_aten_op.yaml +++ b/kernels/aten/edge_dialect_aten_op.yaml @@ -2,6 +2,11 @@ # # This yaml file contains operators that are defined by ExecuTorch and used in ATen mode. +- func: dim_order_ops::_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!) 
+ kernels: + - arg_meta: null + kernel_name: torch::executor::_empty_dim_order_out + - func: dim_order_ops::_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/kernels/portable/cpu/op__empty_dim_order.cpp b/kernels/portable/cpu/op__empty_dim_order.cpp new file mode 100644 index 00000000000..a4d733662f1 --- /dev/null +++ b/kernels/portable/cpu/op__empty_dim_order.cpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +namespace torch { +namespace executor { +namespace native { + +using exec_aten::Tensor; +using OptionalIntArrayRef = exec_aten::OptionalArrayRef; +using DimOrderArrayRef = exec_aten::ArrayRef; + +namespace { + +bool _check__empty_out_dim_order(OptionalIntArrayRef dim_order, Tensor& out) { + DimOrderArrayRef out_dim_order = out.dim_order(); + + if (dim_order.has_value()) { + // out tensor's dim order shall equal to input dim order + IntArrayRef dim_order_ref = dim_order.value(); + + ET_LOG_AND_RETURN_IF_FALSE( + is_channels_last_dim_order( + dim_order.value().data(), dim_order.value().size()) || + is_contiguous_dim_order( + dim_order.value().data(), dim_order.value().size())); + + // Out tensor shall have same dim order as dim_order + ET_LOG_AND_RETURN_IF_FALSE(out_dim_order.size() == dim_order_ref.size()); + for (size_t i = 0; i < dim_order_ref.size(); i++) { + ET_LOG_AND_RETURN_IF_FALSE(out_dim_order[i] == dim_order_ref[i]); + } + } else { // dim_order is not set, out tensor should be contiguous memory + // format + ET_LOG_AND_RETURN_IF_FALSE( + is_contiguous_dim_order(out_dim_order.data(), out_dim_order.size())); + } + return true; +} + +} // namespace + +/* + * Empty out tensor with specified dim order + * + * _empty_dim_order.out(SymInt[] size, *, int[]? dim_order=None, Tensor(a!) out) + * -> Tensor(a!) + */ +Tensor& _empty_dim_order_out( + KernelRuntimeContext& context, + IntArrayRef size, + OptionalIntArrayRef dim_order, + Tensor& out) { + (void)context; + + // Check if dim_order is valid + _check__empty_out_dim_order(dim_order, out); + + // Resize for dynamic shape + ET_KERNEL_CHECK_MSG( + context, + resize_tensor(out, size) == Error::Ok, + InvalidArgument, + out, + "Failed to resize output tensor."); + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index a5d60eb59e4..266b5e446fb 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -937,6 +937,11 @@ - arg_meta: null kernel_name: torch::executor::zeros_out +- func: dim_order_ops::_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: torch::executor::_empty_dim_order_out + - func: dim_order_ops::_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/kernels/test/op__empty_dim_order_test.cpp b/kernels/test/op__empty_dim_order_test.cpp new file mode 100644 index 00000000000..2857bac0458 --- /dev/null +++ b/kernels/test/op__empty_dim_order_test.cpp @@ -0,0 +1,161 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include // Declares the operator +#include +#include +#include +#include +#include + +#include + +using namespace ::testing; +using exec_aten::DimOrderType; +using exec_aten::IntArrayRef; +using exec_aten::optional; +using exec_aten::OptionalArrayRef; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::testing::TensorFactory; + +class OpEmptyDimOrderOutTest : public OperatorTest { + protected: + Tensor& op_empty_dim_order_out( + IntArrayRef size, + OptionalArrayRef dim_order, + Tensor& out) { + return torch::executor::dim_order_ops::_empty_dim_order_outf( + context_, size, dim_order, out); + } + + template + void test_op_empty_dim_order_out(std::vector&& size_int32_t) { + TensorFactory tf; + std::vector sizes(size_int32_t.begin(), size_int32_t.end()); + auto aref = exec_aten::ArrayRef(sizes.data(), sizes.size()); + OptionalArrayRef dim_order; + Tensor out = tf.ones(size_int32_t); + + op_empty_dim_order_out(aref, dim_order, out); + } + + void too_short_dim_order_die() { + TensorFactory tf; + + int64_t sizes[3] = {3, 2, 4}; + auto sizes_aref = exec_aten::ArrayRef(sizes); + + int64_t raw_dim_order[2] = {0, 1}; + auto dim_order = OptionalArrayRef(raw_dim_order); + Tensor out = + tf.ones({3, 2, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); + ET_EXPECT_KERNEL_FAILURE( + context_, op_empty_dim_order_out(sizes_aref, dim_order, out)); + } + + void illegal_dim_order_die() { + TensorFactory tf; + + int64_t sizes[2] = {3, 2}; + auto sizes_aref = exec_aten::ArrayRef(sizes); + + int64_t raw_dim_order[2] = {1, 2}; + auto dim_order = OptionalArrayRef(raw_dim_order); + Tensor out = + tf.ones({3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); + ET_EXPECT_KERNEL_FAILURE( + context_, op_empty_dim_order_out(sizes_aref, dim_order, out)); + } + + void wrong_dim_order_die() { + TensorFactory tf; + + int64_t sizes[4] = {3, 2, 4, 5}; + auto sizes_aref = exec_aten::ArrayRef(sizes); + + // should be {0, 2, 3, 1} + int64_t raw_dim_order[4] = {0, 1, 2, 3}; + auto dim_order = OptionalArrayRef(raw_dim_order); + Tensor out = tf.full_channels_last( + {3, 2, 4, 5}, 1, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); + ET_EXPECT_KERNEL_FAILURE( + context_, op_empty_dim_order_out(sizes_aref, dim_order, out)); + } +}; + +#define GENERATE_TEST(_, DTYPE) \ + TEST_F(OpEmptyDimOrderOutTest, DTYPE##Tensors) { \ + test_op_empty_dim_order_out({2, 3, 4}); \ + test_op_empty_dim_order_out({2, 0, 4}); \ + test_op_empty_dim_order_out({}); \ + } + +ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_TEST) + +TEST_F(OpEmptyDimOrderOutTest, DynamicShapeUpperBoundSameAsExpected) { + TensorFactory tf; + + int64_t sizes[2] = {3, 2}; + auto sizes_aref = exec_aten::ArrayRef(sizes); + OptionalArrayRef dim_order; + Tensor out = + tf.ones({3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); + op_empty_dim_order_out(sizes_aref, dim_order, out); +} + +TEST_F(OpEmptyDimOrderOutTest, ContiguousDimOrderSuccees) { + TensorFactory tf; + + int64_t sizes[2] = {3, 2}; + auto sizes_aref = exec_aten::ArrayRef(sizes); + + int64_t raw_dim_order[2] = {0, 1}; + auto dim_order = OptionalArrayRef(raw_dim_order); + Tensor out = + tf.ones({3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); + op_empty_dim_order_out(sizes_aref, dim_order, out); +} + +TEST_F(OpEmptyDimOrderOutTest, ChannelsLastsDimOrderSuccees) { + TensorFactory tf; + + int64_t sizes[4] = {3, 2, 4, 
5}; + auto sizes_aref = exec_aten::ArrayRef(sizes); + + int64_t raw_dim_order[4] = {0, 2, 3, 1}; + auto dim_order = OptionalArrayRef(raw_dim_order); + Tensor out = tf.full_channels_last( + {3, 2, 4, 5}, 1, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); + op_empty_dim_order_out(sizes_aref, dim_order, out); +} + +TEST_F(OpEmptyDimOrderOutTest, DynamicShapeUpperBoundLargerThanExpected) { + TensorFactory tf; + + int64_t sizes[2] = {3, 2}; + auto sizes_aref = exec_aten::ArrayRef(sizes); + OptionalArrayRef dim_order; + Tensor out = + tf.ones({10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); + op_empty_dim_order_out(sizes_aref, dim_order, out); +} + +TEST_F(OpEmptyDimOrderOutTest, DynamicShapeUnbound) { + if (!torch::executor::testing::SupportedFeatures::get()->output_resize) { + GTEST_SKIP() << "Dynamic shape unbound not supported"; + } + TensorFactory tf; + + int64_t sizes[2] = {3, 2}; + auto sizes_aref = exec_aten::ArrayRef(sizes); + OptionalArrayRef dim_order; + Tensor out = + tf.ones({1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND); + op_empty_dim_order_out(sizes_aref, dim_order, out); +} diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index 77b18a48146..2e7c34f147d 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -174,6 +174,7 @@ def define_common_targets(): codegen_function_header_wrapper("executorch/kernels/test/custom_kernel_example", "custom_kernel_example") _common_op_test("op__to_dim_order_copy_test", ["aten", "portable"]) + _common_op_test("op__empty_dim_order_test", ["aten", "portable"]) _common_op_test("op_abs_test", ["aten", "portable"]) _common_op_test("op_acos_test", ["aten", "portable"]) _common_op_test("op_acosh_test", ["aten", "portable"]) diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index 53698e7f216..b88e83b0586 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -1252,6 +1252,12 @@ ATEN_OPS = ( op_target( name = "op_zeros", ), + op_target( + name = "op__empty_dim_order", + deps = [ + ":scalar_utils", + ], + ), op_target( name = "op__to_dim_order_copy", deps = [
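A minimal end-to-end sketch, for context only (it is not part of the patch): assuming the stock `torch.export` -> `executorch.exir.to_edge` flow with dim-order mode enabled, an `aten.empty.memory_format` call with `memory_format=torch.channels_last` is rewritten by `MemoryFormatOpsPass` into `dim_order_ops._empty_dim_order` with `dim_order=[0, 2, 3, 1]`, which the new `_empty_dim_order_out` kernels registered above then serve at runtime. The inline module mirrors `SimpleEmptyChannelLastModule` from `test_memory_format_ops_pass_utils.py`.

```python
# Illustrative sketch only, not part of this patch. Assumes the standard
# torch.export / executorch.exir.to_edge entry points; the module below mirrors
# SimpleEmptyChannelLastModule added in test_memory_format_ops_pass_utils.py.
import torch
from executorch.exir import EdgeCompileConfig, to_edge


class EmptyChannelsLast(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Traces to aten.empty.memory_format(..., memory_format=channels_last).
        out = torch.empty(x.size(), memory_format=torch.channels_last)
        out.copy_(x.to(memory_format=torch.channels_last))
        return out


sample_input = (torch.randn(1, 10, 24, 24, dtype=torch.float32),)
exported = torch.export.export(EmptyChannelsLast().eval(), sample_input)

# With _skip_dim_order=False, MemoryFormatOpsPass looks up
# aten.empty.memory_format in DimOrderOpsMap and replaces the memory_format
# kwarg with dim_order=[0, 2, 3, 1] on dim_order_ops._empty_dim_order.
edge = to_edge(exported, compile_config=EdgeCompileConfig(_skip_dim_order=False))
print(edge.exported_program().graph_module.code)
```

Because `_empty_dim_order.out` is an out-variant, both the portable and ATen kernels only validate that the pre-allocated `out` tensor already has the requested dim order (contiguous or channels-last) and resize it for dynamic shapes; that is why the C++ tests above allocate `out` with `tf.full_channels_last(...)` when passing `dim_order = {0, 2, 3, 1}`.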