Skip to content

Commit e52a09f

Browse files
tarun292facebook-github-bot
authored andcommitted
Add sort util for 1D tensors (#2786)
Summary: Pull Request resolved: #2786 This diff adds a simple sort utility that sorts a tensor's values and returns the sorted values and the sorted indices in the out tensors that are provided. There are currently two limitations to this sort: - It only supports 1D tensors currently, has to be extended to support 2D and greater tensors. - Input types are assumed to be float and it currently asserts on that. This has to be templatized to support all dtypes. Reviewed By: iseeyuan Differential Revision: D55577025
1 parent 4e18b4b commit e52a09f

File tree

5 files changed

+162
-0
lines changed

5 files changed

+162
-0
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "executorch/kernels/portable/cpu/util/sort_util.h"
10+
#include <executorch/runtime/kernel/kernel_includes.h>
11+
#include <algorithm>
12+
13+
namespace torch {
14+
namespace executor {
15+
16+
using Tensor = exec_aten::Tensor;
17+
18+
Error sort_tensor(
19+
const Tensor& tensor,
20+
Tensor& sorted_tensor,
21+
Tensor& sorted_indices,
22+
bool descending) {
23+
// Check if the input tensor is a valid input
24+
ET_CHECK_MSG(tensor.dim() == 1, "Input tensor must be 1D");
25+
26+
// Check if the output tensors are valid
27+
ET_CHECK_MSG(sorted_tensor.dim() == 1, "Output tensor must be 1D");
28+
ET_CHECK_MSG(sorted_indices.dim() == 1, "Output tensor must be 1D");
29+
30+
// Check if the output tensors have the same dtype
31+
ET_CHECK_MSG(
32+
tensor.scalar_type() == sorted_tensor.scalar_type(),
33+
"Input and output tensors must have the same dtype");
34+
ET_CHECK_MSG(
35+
tensor.scalar_type() == ScalarType::Float,
36+
"Only float inputs are supported currently");
37+
ET_CHECK_MSG(
38+
sorted_indices.scalar_type() == exec_aten::ScalarType::Long,
39+
"Output tensor must be of type int64");
40+
41+
// Get the number of elements in the tensor
42+
int size = tensor.numel();
43+
44+
// Create a tensor to store the indices
45+
for (int i = 0; i < size; i++) {
46+
sorted_indices.mutable_data_ptr<int64_t>()[i] = i;
47+
}
48+
49+
// Sort the indices based on the corresponding tensor values
50+
std::sort(
51+
sorted_indices.mutable_data_ptr<int64_t>(),
52+
sorted_indices.mutable_data_ptr<int64_t>() + size,
53+
[&tensor, descending](int64_t i, int64_t j) {
54+
if (descending) {
55+
return tensor.const_data_ptr<float>()[i] >
56+
tensor.const_data_ptr<float>()[j];
57+
} else {
58+
return tensor.const_data_ptr<float>()[i] <
59+
tensor.const_data_ptr<float>()[j];
60+
}
61+
});
62+
63+
// Rearrange the tensor values based on the sorted indices
64+
for (int i = 0; i < size; i++) {
65+
sorted_tensor.mutable_data_ptr<float>()[i] = tensor.const_data_ptr<
66+
float>()[sorted_indices.const_data_ptr<int64_t>()[i]];
67+
}
68+
69+
return Error::Ok;
70+
}
71+
72+
} // namespace executor
73+
} // namespace torch
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
13+
namespace torch {
14+
namespace executor {
15+
16+
using Tensor = exec_aten::Tensor;
17+
18+
Error sort_tensor(
19+
const Tensor& tensor,
20+
Tensor& sorted_tensor,
21+
Tensor& sorted_indice,
22+
bool descending = false);
23+
24+
} // namespace executor
25+
} // namespace torch

kernels/portable/cpu/util/targets.bzl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,17 @@ def define_common_targets():
247247
visibility = ["//executorch/kernels/portable/cpu/..."],
248248
)
249249

250+
runtime.cxx_library(
251+
name = "sort_util",
252+
srcs = ["sort_util.cpp"],
253+
exported_headers = ["sort_util.h"],
254+
deps = [
255+
"//executorch/runtime/kernel:kernel_includes",
256+
"//executorch/runtime/core/exec_aten/util:tensor_util",
257+
],
258+
visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/torchvision/..."],
259+
)
260+
250261
# Utility functions that can be used by operators that perform reduction
251262
for aten_mode in [True, False]:
252263
suffix = "_aten" if aten_mode else ""
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/util/sort_util.h>
10+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
11+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
12+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
13+
#include <executorch/test/utils/DeathTest.h>
14+
15+
#include <gtest/gtest.h>
16+
17+
using namespace ::testing;
18+
using exec_aten::ScalarType;
19+
using exec_aten::Tensor;
20+
using torch::executor::ArrayRef;
21+
using torch::executor::testing::TensorFactory;
22+
23+
TEST(SortUtilTest, SortTensorTest) {
24+
TensorFactory<ScalarType::Float> tf;
25+
TensorFactory<ScalarType::Long> lf;
26+
27+
Tensor a = tf.make({4}, {3, 2, 1, 4});
28+
Tensor b = tf.zeros({4});
29+
Tensor c = lf.zeros({4});
30+
31+
// Ascending order sort test
32+
sort_tensor(a, b, c);
33+
34+
Tensor expected = tf.make({4}, {1, 2, 3, 4});
35+
Tensor expected_indices = lf.make({4}, {2, 1, 0, 3});
36+
EXPECT_TENSOR_EQ(b, expected);
37+
EXPECT_TENSOR_EQ(c, expected_indices);
38+
39+
// Descending order sort test
40+
sort_tensor(a, b, c, true);
41+
expected = tf.make({4}, {4, 3, 2, 1});
42+
expected_indices = lf.make({4}, {3, 0, 1, 2});
43+
EXPECT_TENSOR_EQ(b, expected);
44+
EXPECT_TENSOR_EQ(c, expected_indices);
45+
}

kernels/portable/cpu/util/test/targets.bzl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,13 @@ def define_common_targets():
2929
"//executorch/runtime/core/exec_aten:lib",
3030
"//executorch/kernels/portable/cpu/util:allocate_tensor_util",
3131
"//executorch/runtime/kernel:kernel_includes",
32+
33+
runtime.cxx_test(
34+
name = "sort_util_test",
35+
srcs = ["sort_util_test.cpp"],
36+
deps = [
37+
"//executorch/runtime/core/exec_aten:lib",
38+
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
39+
"//executorch/kernels/portable/cpu/util:sort_util",
3240
],
3341
)

0 commit comments

Comments
 (0)