diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml index db59af17e6d..f2f18f51c85 100644 --- a/kernels/aten/functions.yaml +++ b/kernels/aten/functions.yaml @@ -243,6 +243,8 @@ - op: masked_scatter.out +- op: masked_select.out + - op: max_pool2d_with_indices.out - op: max.dim_max diff --git a/kernels/portable/cpu/op_masked_select.cpp b/kernels/portable/cpu/op_masked_select.cpp new file mode 100644 index 00000000000..b176000f6c8 --- /dev/null +++ b/kernels/portable/cpu/op_masked_select.cpp @@ -0,0 +1,148 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace torch { +namespace executor { +namespace native { + +Tensor& masked_select_out( + KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& mask, + Tensor& out) { + ScalarType in_type = in.scalar_type(); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_realhbbf16_type(in), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, mask.scalar_type() == ScalarType::Bool, InvalidArgument, out); + ET_KERNEL_CHECK(ctx, out.scalar_type() == in_type, InvalidArgument, out); + + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, mask, out), InvalidArgument, out); + + ET_KERNEL_CHECK( + ctx, tensors_are_broadcastable_between(in, mask), InvalidArgument, out); + + // If input or mask is empty, the output should be empty + if (in.numel() == 0 || mask.numel() == 0) { + ET_KERNEL_CHECK( + ctx, resize_tensor(out, {0}) == Error::Ok, InvalidArgument, out); + return out; + } + + // Compute the shape resulting from broadcasting the mask against the input + size_t broadcast_ndim = 0; + Tensor::SizesType broadcast_sizes[kTensorDimensionLimit]; + Error err = get_broadcast_target_size( + in, mask, broadcast_sizes, kTensorDimensionLimit, &broadcast_ndim); + if (err != Error::Ok) { + ET_KERNEL_CHECK_MSG( + ctx, false, InvalidArgument, out, "Failed to broadcast input and mask"); + } + size_t broadcast_numel = 1; + for (size_t i = 0; i < broadcast_ndim; i++) { + broadcast_numel *= broadcast_sizes[i]; + } + + // Compute the number of out elements + size_t mask_true_count = 0; + const bool* const mask_data = mask.const_data_ptr(); + for (size_t i = 0; i < mask.numel(); ++i) { + if (mask_data[i]) { + mask_true_count++; + } + } + Tensor::SizesType out_numel = + mask_true_count * (broadcast_numel / mask.numel()); + + // Resize the out tensor + ET_KERNEL_CHECK( + ctx, resize_tensor(out, {out_numel}) == Error::Ok, InvalidArgument, out); + + const char* const in_data = + reinterpret_cast(in.const_data_ptr()); + char* const out_data = reinterpret_cast(out.mutable_data_ptr()); + const auto elem_size = in.element_size(); + + // Figure out if `in` is broadcasted + bool in_is_broadcasted = false; + if (in.dim() != broadcast_ndim) { + in_is_broadcasted = true; + } else { + for (size_t i = 0; i < in.dim(); ++i) { + if (in.size(i) != broadcast_sizes[i]) { + in_is_broadcasted = true; + } + } + } + + // Figure out if `mask` is broadcasted + bool mask_is_broadcasted = false; + if (mask.dim() != broadcast_ndim) { + mask_is_broadcasted = true; + } else { + for (size_t i = 0; i < mask.dim(); ++i) { + if (mask.size(i) != broadcast_sizes[i]) { + mask_is_broadcasted = true; + } + } + } + + // Figure out if either `in` or `mask` is broadcasted + bool any_is_broadcasted = (in_is_broadcasted || mask_is_broadcasted); + + size_t out_ix = 0; + for (size_t i = 0; i < broadcast_numel; ++i) { + size_t in_linear_index = i; + size_t mask_linear_index = i; + + // If either `in` or `mask` is broadcasted, we need to compute the indexes + // in the broadcasted space. + if (any_is_broadcasted) { + size_t broadcast_indexes[kTensorDimensionLimit]; + delinearize_index( + i, + {broadcast_sizes, broadcast_ndim}, + broadcast_indexes, + kTensorDimensionLimit); + + if (in_is_broadcasted) { + in_linear_index = + linearize_access_indexes(broadcast_indexes, broadcast_ndim, in); + } + if (mask_is_broadcasted) { + mask_linear_index = + linearize_access_indexes(broadcast_indexes, broadcast_ndim, mask); + } + } + + // If the mask is true, copy the value from `in` to `out` and increment the + // `out_ix` + if (mask_data[mask_linear_index]) { + memcpy( + out_data + out_ix * elem_size, + in_data + in_linear_index * elem_size, + elem_size); + out_ix++; + } + } + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index d1eb8b8a3bf..a5d60eb59e4 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -547,6 +547,11 @@ - arg_meta: null kernel_name: torch::executor::masked_scatter_out +- op: masked_select.out + kernels: + - arg_meta: null + kernel_name: torch::executor::masked_select_out + - op: max.dim_max kernels: - arg_meta: null diff --git a/kernels/test/op_masked_select_test.cpp b/kernels/test/op_masked_select_test.cpp new file mode 100644 index 00000000000..2a7791e9c18 --- /dev/null +++ b/kernels/test/op_masked_select_test.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include // Declares the operator +#include +#include +#include +#include +#include + +#include + +using namespace ::testing; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::testing::SupportedFeatures; +using torch::executor::testing::TensorFactory; + +class OpMaskedSelectOutTest : public OperatorTest { + protected: + Tensor& + op_masked_select_out(const Tensor& in, const Tensor& mask, Tensor& out) { + return torch::executor::aten::masked_select_outf(context_, in, mask, out); + } +}; + +TEST_F(OpMaskedSelectOutTest, SmokeTest) { + TensorFactory tf; + TensorFactory tfBool; + + Tensor in = tf.make({2, 3}, {1, 2, 3, 4, 5, 6}); + Tensor mask = tfBool.make({2, 3}, {true, false, false, true, false, true}); + Tensor out = tf.zeros({3}); + + op_masked_select_out(in, mask, out); + EXPECT_TENSOR_EQ(out, tf.make({3}, {1, 4, 6})); +} + +TEST_F(OpMaskedSelectOutTest, BroadcastInput) { + TensorFactory tf; + TensorFactory tfBool; + + Tensor in = tf.make({3}, {1, 2, 3}); + Tensor mask = tfBool.make({2, 3}, {true, false, false, true, false, true}); + Tensor out = tf.zeros({3}); + + op_masked_select_out(in, mask, out); + EXPECT_TENSOR_EQ(out, tf.make({3}, {1, 1, 3})); +} + +TEST_F(OpMaskedSelectOutTest, BroadcastMask) { + TensorFactory tf; + TensorFactory tfBool; + + Tensor in = tf.make({2, 3}, {1, 2, 3, 4, 5, 6}); + Tensor mask = tfBool.make({3}, {false, true, false}); + + Tensor out = tf.zeros({2}); + + op_masked_select_out(in, mask, out); + EXPECT_TENSOR_EQ(out, tf.make({2}, {2, 5})); +} + +TEST_F(OpMaskedSelectOutTest, BroadcastInputAndMask) { + TensorFactory tf; + TensorFactory tfBool; + + Tensor in = tf.ones({2, 3, 4, 1}); + Tensor mask = tfBool.ones({2, 1, 1, 5}); + Tensor out = tf.zeros({120}); + + op_masked_select_out(in, mask, out); + EXPECT_TENSOR_EQ(out, tf.ones({120})); +} + +TEST_F(OpMaskedSelectOutTest, EmptyInput) { + TensorFactory tf; + TensorFactory tfBool; + + Tensor in = tf.make({2, 0}, {}); + Tensor mask = tfBool.make({2, 1}, {true, true}); + Tensor out = tf.zeros({0}); + + op_masked_select_out(in, mask, out); + EXPECT_TENSOR_EQ(out, tf.make({0}, {})); +} + +TEST_F(OpMaskedSelectOutTest, EmptyMask) { + TensorFactory tf; + TensorFactory tfBool; + + Tensor in = tf.make({2, 1}, {100, 200}); + Tensor mask = tfBool.make({2, 0}, {}); + Tensor out = tf.zeros({0}); + + op_masked_select_out(in, mask, out); + EXPECT_TENSOR_EQ(out, tf.make({0}, {})); +} + +TEST_F(OpMaskedSelectOutTest, EmptyInputAndMask) { + TensorFactory tf; + TensorFactory tfBool; + + Tensor in = tf.make({2, 0}, {}); + Tensor mask = tfBool.make({0}, {}); + Tensor out = tf.zeros({0}); + + op_masked_select_out(in, mask, out); + EXPECT_TENSOR_EQ(out, tf.make({0}, {})); +} diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index 91b3ba89fde..ce15a578adf 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -255,6 +255,7 @@ def define_common_targets(): _common_op_test("op_lt_test", ["aten", "portable"]) _common_op_test("op_masked_fill_test", ["aten", "portable"]) _common_op_test("op_masked_scatter_test", ["aten", "portable"]) + _common_op_test("op_masked_select_test", ["aten", "portable"]) _common_op_test("op_max_test", ["aten", "portable"]) _common_op_test("op_max_pool2d_with_indices_test", ["aten", "portable"]) _common_op_test("op_maximum_test", ["aten", "portable"]) diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index f63932d4840..ab8fc63a2af 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -789,6 +789,12 @@ ATEN_OPS = ( "//executorch/kernels/portable/cpu/util:broadcast_util", ], ), + op_target( + name = "op_masked_select", + deps = [ + "//executorch/kernels/portable/cpu/util:broadcast_util", + ], + ), op_target( name = "op_max", deps = [