diff --git a/kernels/portable/cpu/op_index.cpp b/kernels/portable/cpu/op_index.cpp
index a81ce6ad737..f198a73b45a 100644
--- a/kernels/portable/cpu/op_index.cpp
+++ b/kernels/portable/cpu/op_index.cpp
@@ -22,21 +22,159 @@ namespace native {
 using Tensor = executorch::aten::Tensor;
 using TensorOptList =
     executorch::aten::ArrayRef<executorch::aten::optional<Tensor>>;
 
-Tensor& index_Tensor_out(
+namespace {
+
+bool check_fast_path_conditions(
+    ET_UNUSED const Tensor& in,
+    TensorOptList indices,
+    size_t* dim) {
+  bool found_index = false;
+  for (const auto i : c10::irange(indices.size())) {
+    if (indices[i].has_value()) {
+      *dim = i;
+      // Fast path only supports a single non-null index tensor
+      if (found_index) {
+        return false;
+      }
+      found_index = true;
+      const Tensor& index = indices[i].value();
+      ScalarType ix_type = index.scalar_type();
+      // Fast path only supports Long or Int index tensors
+      if (ix_type != ScalarType::Long && ix_type != ScalarType::Int) {
+        return false;
+      }
+      // Fast path only supports a 1-dimensional index tensor
+      if (index.dim() != 1) {
+        return false;
+      }
+    }
+  }
+
+  // Fast path needs at least one non-null index tensor
+  if (!found_index) {
+    return false;
+  }
+
+  return true;
+}
+
+bool check_fast_path_args(
+    const Tensor& in,
+    TensorOptList indices,
+    size_t dim,
+    Tensor& out) {
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+
+  ET_CHECK_OR_RETURN_FALSE(
+      static_cast<ssize_t>(indices.size()) <= in.dim(),
+      "Indexing too many dimensions");
+
+  const Tensor& index = indices[dim].value();
+
+  bool is_valid_index = true;
+  ET_SWITCH_TWO_TYPES(
+      Long, Int, index.scalar_type(), ctx, "index_put_", CTYPE, [&]() {
+        const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+        for (const auto i : c10::irange(index.numel())) {
+          if (index_arr[i] < 0 ||
+              index_arr[i] >= static_cast<CTYPE>(in.size(dim))) {
+            ET_LOG(
+                Error,
+                "Index %" PRId64
+                " out of range for tensor with size %zd"
+                " at dimension %zu",
+                static_cast<int64_t>(index_arr[i]),
+                in.size(dim),
+                dim);
+            is_valid_index = false;
+            break;
+          }
+        }
+      });
+
+  ET_CHECK_OR_RETURN_FALSE(
+      is_valid_index,
+      "Some index values are not within bounds of input tensor at indexed dim");
+
+  return true;
+}
+
+Tensor& fast_path(
     KernelRuntimeContext& ctx,
     const Tensor& in,
     TensorOptList indices,
+    size_t dim,
     Tensor& out) {
   (void)ctx;
 
   ET_KERNEL_CHECK(
-      ctx, check_index_args(in, indices, out), InvalidArgument, out);
+      ctx, check_fast_path_args(in, indices, dim, out), InvalidArgument, out);
+
+  const Tensor& index = indices[dim].value();
+  ScalarType index_type = index.scalar_type();
+
+  if (out.dim() == 0) {
+    memcpy(out.mutable_data_ptr(), in.const_data_ptr(), out.nbytes());
+    return out;
+  }
+
+  size_t leading_dims = getLeadingDims(in, dim);
+  size_t trailing_dims = getTrailingDims(in, dim);
+
+  if (leading_dims == 0 || trailing_dims == 0) {
+    return out;
+  }
+
+  size_t in_dim_length = in.size(dim);
+  size_t out_dim_length = out.size(dim);
+
+  size_t length_per_step = trailing_dims * in.element_size();
+
+  const char* in_data = in.const_data_ptr<char>();
+  char* out_data = out.mutable_data_ptr<char>();
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "index.Tensor_out";
+
+  ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, op_name, CTYPE, [&]() {
+    const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+    for (const auto i : c10::irange(leading_dims)) {
+      const char* src = in_data + i * in_dim_length * length_per_step;
+      char* dest = out_data + i * out_dim_length * length_per_step;
+      for (const auto j : c10::irange(out_dim_length)) {
+        const char* copy_src = src + index_arr[j] * length_per_step;
+        char* copy_dest = dest + j * length_per_step;
+        memcpy(copy_dest, copy_src, length_per_step);
+      }
+    }
+  });
+
+  return out;
+}
+
+} // namespace
+
+Tensor& index_Tensor_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    TensorOptList indices,
+    Tensor& out) {
+  (void)ctx;
 
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
 
+  size_t dim = 0;
+  bool is_fast_path = check_fast_path_conditions(in, indices, &dim);
+  if (is_fast_path) {
+    return fast_path(ctx, in, indices, dim, out);
+  }
+
+  ET_KERNEL_CHECK(
+      ctx, check_index_args(in, indices, out), InvalidArgument, out);
+
   ScalarType in_type = in.scalar_type();
   size_t block_count = count_index_blocks(indices);
diff --git a/kernels/test/op_index_test.cpp b/kernels/test/op_index_test.cpp
index 2471d44b0a3..0c8cecd9291 100644
--- a/kernels/test/op_index_test.cpp
+++ b/kernels/test/op_index_test.cpp
@@ -627,3 +627,61 @@ TEST_F(OpIndexTensorOutTest, UpperBoundOutTensor) {
   EXPECT_TENSOR_EQ(out, ret);
   EXPECT_TENSOR_EQ(ret, expected);
 }
+
+TEST_F(OpIndexTensorOutTest, FastPath) {
+  TensorFactory<ScalarType::Double> tf;
+  TensorFactory<ScalarType::Long> tfl;
+
+  // clang-format off
+  Tensor x = tf.make(
+      {2, 3, 4},
+      {
+          // [0, :, :]
+          1.,   2.,   3.,   4.,  // [0, 0, :]
+          5.,   6.,   7.,   8.,  // [0, 1, :]
+          9.,  10.,  11.,  12.,  // [0, 2, :]
+
+          // [1, :, :]
+         -1.,  -2.,  -3.,  -4.,  // [1, 0, :]
+         -5.,  -6.,  -7.,  -8.,  // [1, 1, :]
+         -9., -10., -11., -12.,  // [1, 2, :]
+      });
+  // clang-format on
+
+  optional<Tensor> indices[] = {
+      optional<Tensor>(),
+      optional<Tensor>(),
+      optional<Tensor>(tfl.make({3}, {2, 0, 1}))};
+
+  Tensor out = tf.zeros({2, 3, 3});
+  // clang-format off
+  Tensor expected = tf.make(
+      {2, 3, 3},
+      {
+          3.,   1.,   2.,
+          7.,   5.,   6.,
+         11.,   9.,  10.,
+
+         -3.,  -1.,  -2.,
+         -7.,  -5.,  -6.,
+        -11.,  -9., -10.,
+      });
+  // clang-format on
+
+  op_index_tensor_out(x, indices, out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST_F(OpIndexTensorOutTest, FastPathZeroDim) {
+  TensorFactory<ScalarType::Double> tf;
+  TensorFactory<ScalarType::Long> tfl;
+
+  Tensor x = tf.ones({0});
+  optional<Tensor> indices[] = {optional<Tensor>(tfl.zeros({0}))};
+  Tensor out = tf.zeros({0});
+  Tensor expected = tf.ones({0});
+  op_index_tensor_out(x, indices, out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
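
A minimal standalone sketch (not part of the patch) of the copy scheme the fast path implements, assuming row-major contiguous storage in a plain std::vector. Shapes, the indexed dimension, and the index values mirror the FastPath test above; variable names follow the kernel but all of this code is illustrative only.

#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  const size_t sizes[] = {2, 3, 4}; // input shape, as in the FastPath test
  const size_t dim = 2;             // the single indexed dimension
  const std::vector<long> index = {2, 0, 1};

  std::vector<double> in(2 * 3 * 4);
  for (size_t i = 0; i < in.size(); ++i) {
    in[i] = static_cast<double>(i);
  }

  // leading_dims  = product of sizes before `dim` -> 2 * 3 = 6
  // trailing_dims = product of sizes after  `dim` -> 1
  size_t leading_dims = sizes[0] * sizes[1];
  size_t trailing_dims = 1;
  size_t in_dim_length = sizes[dim];    // 4
  size_t out_dim_length = index.size(); // 3

  std::vector<double> out(leading_dims * out_dim_length * trailing_dims);

  // For every "row" formed by the leading dims, gather out[j] = in[index[j]]
  // along the indexed dimension; each memcpy moves trailing_dims elements,
  // matching length_per_step = trailing_dims * element_size in the kernel.
  for (size_t i = 0; i < leading_dims; ++i) {
    const double* src = in.data() + i * in_dim_length * trailing_dims;
    double* dest = out.data() + i * out_dim_length * trailing_dims;
    for (size_t j = 0; j < out_dim_length; ++j) {
      std::memcpy(
          dest + j * trailing_dims,
          src + index[j] * trailing_dims,
          trailing_dims * sizeof(double));
    }
  }

  // in[0, 0, :] == {0, 1, 2, 3}, so out[0, 0, :] should be {2, 0, 1}.
  std::printf("%g %g %g\n", out[0], out[1], out[2]);
  return 0;
}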