diff --git a/kernels/portable/cpu/op_index.cpp b/kernels/portable/cpu/op_index.cpp index 8fbf903400a..6ce9fb375de 100644 --- a/kernels/portable/cpu/op_index.cpp +++ b/kernels/portable/cpu/op_index.cpp @@ -1,11 +1,13 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ +#include #include #include #include @@ -47,6 +49,23 @@ bool check_fast_path_conditions( if (index.dim() != 1) { return false; } + + // Fast path only supports non-negative indices. + if (ix_type == ScalarType::Int) { + const int32_t* const data = index.const_data_ptr(); + if (std::any_of(data, data + index.numel(), [](const auto x) { + return x < 0; + })) { + return false; + } + } else { // ScalarType::Long + const int64_t* const data = index.const_data_ptr(); + if (std::any_of(data, data + index.numel(), [](const auto x) { + return x < 0; + })) { + return false; + } + } } } diff --git a/kernels/test/op_index_test.cpp b/kernels/test/op_index_test.cpp index f3e1d9081c0..787eb4612d8 100644 --- a/kernels/test/op_index_test.cpp +++ b/kernels/test/op_index_test.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -480,6 +481,36 @@ TEST_F(OpIndexTensorOutTest, AllDtypesSupportedForIndex) { test_dtype(); } +TEST_F(OpIndexTensorOutTest, NegativeIndexSupportedForLong) { + TensorFactory tf; + TensorFactory tfl; + + Tensor x = tf.make({3}, {1., 2., 3.}); + Tensor out = tf.zeros({1}); + Tensor expected = tf.make({1}, {3.}); + + std::array, 1> indices = { + optional(tfl.make({1}, {-1}))}; + + Tensor ret = op_index_tensor_out(x, indices, out); + EXPECT_TENSOR_EQ(ret, expected); +} + +TEST_F(OpIndexTensorOutTest, NegativeIndexSupportedForInt) { + TensorFactory tf; + TensorFactory tfi; + + Tensor x = tf.make({3}, {1., 2., 3.}); + Tensor out = tf.zeros({1}); + Tensor expected = tf.make({1}, {3.}); + + std::array, 1> indices = { + optional(tfi.make({1}, {-1}))}; + + Tensor ret = op_index_tensor_out(x, indices, out); + EXPECT_TENSOR_EQ(ret, expected); +} + // // Death Tests //