-
Notifications
You must be signed in to change notification settings - Fork 722
Add sort util for 1D tensors #2786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| /* | ||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| * All rights reserved. | ||
| * | ||
| * This source code is licensed under the BSD-style license found in the | ||
| * LICENSE file in the root directory of this source tree. | ||
| */ | ||
|
|
||
| #include "executorch/kernels/portable/cpu/util/allocate_tensor_util.h" | ||
|
|
||
|
|
||
| namespace torch { | ||
| namespace executor { | ||
|
|
||
| using Tensor = exec_aten::Tensor; | ||
| using ScalarType = exec_aten::ScalarType; | ||
|
|
||
| Tensor allocate_tensor( | ||
| KernelRuntimeContext& ctx, | ||
| const ArrayRef<Tensor::SizesType>& sizes, | ||
| const ArrayRef<Tensor::DimOrderType>& dim_order, | ||
| const ArrayRef<Tensor::StridesType>& strides, | ||
| const ScalarType& dtype) { | ||
| int dim = sizes.size(); | ||
| int size_nbytes = dim * sizeof(Tensor::SizesType); | ||
| Result<void*> temp_mem_res_size = ctx.allocate_temp(size_nbytes); | ||
| void* size_data_ptr = | ||
| temp_mem_res_size.ok() ? temp_mem_res_size.get() : nullptr; | ||
| ET_CHECK_MSG(size_data_ptr != nullptr, "Failed to malloc for size bytes"); | ||
| memcpy(size_data_ptr, sizes.data(), size_nbytes); | ||
|
|
||
| // TODO(T145322324): can we remove the static cast once size is unsigned? | ||
| size_t dim_order_nbytes = | ||
| static_cast<size_t>(dim) * sizeof(Tensor::DimOrderType); | ||
| Result<void*> temp_mem_res_dim_order = ctx.allocate_temp(dim_order_nbytes); | ||
| void* dim_order_data_ptr = | ||
| temp_mem_res_dim_order.ok() ? temp_mem_res_dim_order.get() : nullptr; | ||
| ET_CHECK_MSG( | ||
| dim_order_data_ptr != nullptr, "Failed to malloc for dim order bytes"); | ||
| memcpy(dim_order_data_ptr, dim_order.data(), dim_order_nbytes); | ||
|
|
||
| int strides_nbytes = dim * sizeof(Tensor::StridesType); | ||
| Result<void*> temp_mem_res_strides = ctx.allocate_temp(strides_nbytes); | ||
| void* strides_data_ptr = | ||
| temp_mem_res_strides.ok() ? temp_mem_res_strides.get() : nullptr; | ||
| printf("strides_data_ptr: %p\n", strides_data_ptr); | ||
| fflush(stdout); | ||
| ET_CHECK_MSG( | ||
| strides_data_ptr != nullptr, "Failed to malloc for strides bytes"); | ||
| memcpy(strides_data_ptr, strides.data(), strides_nbytes); | ||
|
|
||
| Result<void*> temp_mem_res_tensor = ctx.allocate_temp(sizeof(TensorImpl)); | ||
| auto tensor_impl = static_cast<TensorImpl*>( | ||
| temp_mem_res_tensor.ok() ? temp_mem_res_tensor.get() : nullptr); | ||
| ET_CHECK_MSG(tensor_impl != nullptr, "Failed to malloc for data TensorImpl"); | ||
|
|
||
| new (tensor_impl) TensorImpl( | ||
| dtype, | ||
| dim, | ||
| reinterpret_cast<Tensor::SizesType*>(size_data_ptr), | ||
| nullptr, | ||
| reinterpret_cast<Tensor::DimOrderType*>(dim_order_data_ptr), | ||
| reinterpret_cast<Tensor::StridesType*>(strides_data_ptr)); | ||
|
|
||
| Result<void*> temp_mem_res_data = ctx.allocate_temp(tensor_impl->nbytes()); | ||
| void* data_ptr = temp_mem_res_data.ok() ? temp_mem_res_data.get() : nullptr; | ||
| ET_CHECK_MSG(data_ptr != nullptr, "Failed to malloc for data buffer"); | ||
| tensor_impl->set_data(data_ptr); | ||
|
|
||
| return Tensor{tensor_impl}; | ||
| } | ||
|
|
||
| } // namespace executor | ||
| } // namespace torch |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| // (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <executorch/runtime/kernel/kernel_includes.h> | ||
|
|
||
| namespace torch { | ||
| namespace executor { | ||
|
|
||
| Tensor allocate_tensor( | ||
| KernelRuntimeContext& ctx, | ||
| const ArrayRef<Tensor::SizesType>& sizes, | ||
| const ArrayRef<Tensor::DimOrderType>& dim_order, | ||
| const ArrayRef<Tensor::StridesType>& strides, | ||
| const ScalarType& dtype); | ||
|
|
||
| } // namespace executor | ||
| } // namespace torch |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| /* | ||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| * All rights reserved. | ||
| * | ||
| * This source code is licensed under the BSD-style license found in the | ||
| * LICENSE file in the root directory of this source tree. | ||
| */ | ||
|
|
||
| #include "executorch/kernels/portable/cpu/util/sort_util.h" | ||
| #include <executorch/runtime/kernel/kernel_includes.h> | ||
| #include <algorithm> | ||
|
|
||
| namespace torch { | ||
| namespace executor { | ||
|
|
||
| using Tensor = exec_aten::Tensor; | ||
|
|
||
| Error sort_tensor( | ||
| const Tensor& tensor, | ||
| Tensor& sorted_tensor, | ||
| Tensor& sorted_indices, | ||
| bool descending) { | ||
| // Check if the input tensor is a valid input | ||
| ET_CHECK_MSG(tensor.dim() == 1, "Input tensor must be 1D"); | ||
|
|
||
| // Check if the output tensors are valid | ||
| ET_CHECK_MSG(sorted_tensor.dim() == 1, "Output tensor must be 1D"); | ||
| ET_CHECK_MSG(sorted_indices.dim() == 1, "Output tensor must be 1D"); | ||
|
|
||
| // Check if the output tensors have the same dtype | ||
| ET_CHECK_MSG( | ||
| tensor.scalar_type() == sorted_tensor.scalar_type(), | ||
| "Input and output tensors must have the same dtype"); | ||
| ET_CHECK_MSG( | ||
| tensor.scalar_type() == ScalarType::Float, | ||
| "Only float inputs are supported currently"); | ||
| ET_CHECK_MSG( | ||
| sorted_indices.scalar_type() == exec_aten::ScalarType::Long, | ||
| "Output tensor must be of type int64"); | ||
|
|
||
| // Get the number of elements in the tensor | ||
| int size = tensor.numel(); | ||
|
|
||
| // Create a tensor to store the indices | ||
| for (int i = 0; i < size; i++) { | ||
| sorted_indices.mutable_data_ptr<int64_t>()[i] = i; | ||
| } | ||
|
|
||
| // Sort the indices based on the corresponding tensor values | ||
| std::sort( | ||
| sorted_indices.mutable_data_ptr<int64_t>(), | ||
| sorted_indices.mutable_data_ptr<int64_t>() + size, | ||
| [&tensor, descending](int64_t i, int64_t j) { | ||
| if (descending) { | ||
| return tensor.const_data_ptr<float>()[i] > | ||
| tensor.const_data_ptr<float>()[j]; | ||
| } else { | ||
| return tensor.const_data_ptr<float>()[i] < | ||
| tensor.const_data_ptr<float>()[j]; | ||
| } | ||
| }); | ||
|
|
||
| // Rearrange the tensor values based on the sorted indices | ||
| for (int i = 0; i < size; i++) { | ||
| sorted_tensor.mutable_data_ptr<float>()[i] = tensor.const_data_ptr< | ||
| float>()[sorted_indices.const_data_ptr<int64_t>()[i]]; | ||
| } | ||
|
|
||
| return Error::Ok; | ||
| } | ||
|
|
||
| } // namespace executor | ||
| } // namespace torch |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| /* | ||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| * All rights reserved. | ||
| * | ||
| * This source code is licensed under the BSD-style license found in the | ||
| * LICENSE file in the root directory of this source tree. | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <executorch/runtime/core/exec_aten/exec_aten.h> | ||
|
|
||
| namespace torch { | ||
| namespace executor { | ||
|
|
||
| using Tensor = exec_aten::Tensor; | ||
|
|
||
| Error sort_tensor( | ||
| const Tensor& tensor, | ||
| Tensor& sorted_tensor, | ||
| Tensor& sorted_indice, | ||
| bool descending = false); | ||
|
|
||
| } // namespace executor | ||
| } // namespace torch |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| /* | ||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| * All rights reserved. | ||
| * | ||
| * This source code is licensed under the BSD-style license found in the | ||
| * LICENSE file in the root directory of this source tree. | ||
| */ | ||
|
|
||
| #include <gtest/gtest.h> | ||
|
|
||
| #include <executorch/kernels/portable/cpu/util/allocate_tensor_util.h> | ||
| #include <executorch/runtime/kernel/kernel_includes.h> | ||
| #include <executorch/runtime/platform/runtime.h> | ||
| #include <executorch/test/utils/DeathTest.h> | ||
| using ScalarType = exec_aten::ScalarType; | ||
|
|
||
| class AllocateTest : public ::testing::Test { | ||
| protected: | ||
| void SetUp() override { | ||
| // Since these tests cause ET_LOG to be called, the PAL must be initialized | ||
| // first. | ||
| torch::executor::runtime_init(); | ||
| } | ||
| }; | ||
|
|
||
| TEST(AllocateTest, AllocateTensor) { | ||
| uint8_t* temp_allocator_ptr = (uint8_t*)malloc(2048); | ||
| executorch::runtime::MemoryAllocator temp_allocator(2048, temp_allocator_ptr); | ||
| executorch::runtime::KernelRuntimeContext ctx(nullptr, &temp_allocator); | ||
|
|
||
| executorch::aten::SizesType sizes[3] = {1, 2, 3}; | ||
| executorch::aten::DimOrderType dim_order[3] = {0, 1, 2}; | ||
| executorch::aten::StridesType strides[3] = {3, 3, 1}; | ||
|
|
||
| torch::executor::ArrayRef<exec_aten::SizesType> sizes_ref(sizes, 3); | ||
| torch::executor::ArrayRef<exec_aten::StridesType> strides_ref(strides, 3); | ||
| torch::executor::ArrayRef<exec_aten::DimOrderType> dim_orders_ref( | ||
| dim_order, 3); | ||
|
|
||
| torch::executor::allocate_tensor( | ||
| ctx, sizes, dim_order, strides, ScalarType::Float); | ||
|
|
||
| free(temp_allocator_ptr); | ||
| } | ||
|
|
||
| TEST(AllocateTest, FailAllocateTensor) { | ||
| torch::executor::runtime_init(); | ||
|
|
||
| uint8_t* temp_allocator_ptr = (uint8_t*)malloc(16); | ||
| executorch::runtime::MemoryAllocator temp_allocator(16, temp_allocator_ptr); | ||
| executorch::runtime::KernelRuntimeContext ctx(nullptr, &temp_allocator); | ||
|
|
||
| executorch::aten::SizesType sizes[3] = {1, 2, 3}; | ||
| executorch::aten::DimOrderType dim_order[3] = {0, 1, 2}; | ||
| executorch::aten::StridesType strides[3] = {3, 3, 1}; | ||
|
|
||
| torch::executor::ArrayRef<exec_aten::SizesType> sizes_ref(sizes, 3); | ||
| torch::executor::ArrayRef<exec_aten::StridesType> strides_ref(strides, 3); | ||
| torch::executor::ArrayRef<exec_aten::DimOrderType> dim_orders_ref( | ||
| dim_order, 3); | ||
|
|
||
| ET_EXPECT_DEATH( | ||
| torch::executor::allocate_tensor( | ||
| ctx, sizes, dim_order, strides, ScalarType::Float), | ||
| "Failed to malloc"); | ||
|
|
||
| free(temp_allocator_ptr); | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| /* | ||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| * All rights reserved. | ||
| * | ||
| * This source code is licensed under the BSD-style license found in the | ||
| * LICENSE file in the root directory of this source tree. | ||
| */ | ||
|
|
||
| #include <executorch/kernels/portable/cpu/util/sort_util.h> | ||
| #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h> | ||
| #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h> | ||
| #include <executorch/runtime/core/exec_aten/util/tensor_util.h> | ||
| #include <executorch/test/utils/DeathTest.h> | ||
|
|
||
| #include <gtest/gtest.h> | ||
|
|
||
| using namespace ::testing; | ||
| using exec_aten::ScalarType; | ||
| using exec_aten::Tensor; | ||
| using torch::executor::ArrayRef; | ||
| using torch::executor::testing::TensorFactory; | ||
|
|
||
| TEST(SortUtilTest, SortTensorTest) { | ||
| TensorFactory<ScalarType::Float> tf; | ||
| TensorFactory<ScalarType::Long> lf; | ||
|
|
||
| Tensor a = tf.make({4}, {3, 2, 1, 4}); | ||
| Tensor b = tf.zeros({4}); | ||
| Tensor c = lf.zeros({4}); | ||
|
|
||
| // Ascending order sort test | ||
| sort_tensor(a, b, c); | ||
|
|
||
| Tensor expected = tf.make({4}, {1, 2, 3, 4}); | ||
| Tensor expected_indices = lf.make({4}, {2, 1, 0, 3}); | ||
| EXPECT_TENSOR_EQ(b, expected); | ||
| EXPECT_TENSOR_EQ(c, expected_indices); | ||
|
|
||
| // Descending order sort test | ||
| sort_tensor(a, b, c, true); | ||
| expected = tf.make({4}, {4, 3, 2, 1}); | ||
| expected_indices = lf.make({4}, {3, 0, 1, 2}); | ||
| EXPECT_TENSOR_EQ(b, expected); | ||
| EXPECT_TENSOR_EQ(c, expected_indices); | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice test!