Skip to content

Commit f2ee743

Browse files
tarun292 authored and facebook-github-bot committed
Add allocate tensor util that uses temp allocator
Differential Revision: D64072692
1 parent f663ba6 commit f2ee743

File tree

5 files changed

+182
-0
lines changed

5 files changed

+182
-0
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "executorch/kernels/portable/cpu/util/allocate_tensor_util.h"
10+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
11+
12+
namespace torch {
13+
namespace executor {
14+
15+
using Tensor = exec_aten::Tensor;
16+
using ScalarType = exec_aten::ScalarType;
17+
18+
/**
 * Allocates a Tensor whose metadata (sizes, dim order, strides), TensorImpl,
 * and data buffer all live in the kernel runtime context's temp allocator.
 *
 * NOTE(review): the returned tensor is only valid while the temp allocator's
 * backing memory is alive — confirm lifetime expectations with callers.
 *
 * @param ctx Kernel runtime context providing the temp allocator.
 * @param sizes Size of each dimension; its length defines the tensor rank.
 * @param dim_order Dimension order; assumed to match sizes in length.
 * @param strides Strides; assumed to match sizes in length.
 * @param dtype Scalar type of the tensor elements.
 * @return A Tensor backed entirely by temp-allocated memory. Aborts via
 *         ET_CHECK_MSG if any temp allocation fails.
 */
Tensor allocate_tensor(
    KernelRuntimeContext& ctx,
    const ArrayRef<Tensor::SizesType>& sizes,
    const ArrayRef<Tensor::DimOrderType>& dim_order,
    const ArrayRef<Tensor::StridesType>& strides,
    const ScalarType& dtype) {
  const int dim = sizes.size();

  // Copy the sizes into temp-allocated storage so the TensorImpl can keep
  // referring to them after this function returns.
  // TODO(T145322324): can we remove the static cast once size is unsigned?
  const size_t size_nbytes =
      static_cast<size_t>(dim) * sizeof(Tensor::SizesType);
  Result<void*> temp_mem_res_size = ctx.allocate_temp(size_nbytes);
  void* size_data_ptr =
      temp_mem_res_size.ok() ? temp_mem_res_size.get() : nullptr;
  ET_CHECK_MSG(size_data_ptr != nullptr, "Failed to malloc for size bytes");
  memcpy(size_data_ptr, sizes.data(), size_nbytes);

  // Copy the dim order.
  const size_t dim_order_nbytes =
      static_cast<size_t>(dim) * sizeof(Tensor::DimOrderType);
  Result<void*> temp_mem_res_dim_order = ctx.allocate_temp(dim_order_nbytes);
  void* dim_order_data_ptr =
      temp_mem_res_dim_order.ok() ? temp_mem_res_dim_order.get() : nullptr;
  ET_CHECK_MSG(
      dim_order_data_ptr != nullptr, "Failed to malloc for dim order bytes");
  memcpy(dim_order_data_ptr, dim_order.data(), dim_order_nbytes);

  // Copy the strides. (Debug printf/fflush removed; byte count made size_t
  // for consistency with dim_order_nbytes above.)
  const size_t strides_nbytes =
      static_cast<size_t>(dim) * sizeof(Tensor::StridesType);
  Result<void*> temp_mem_res_strides = ctx.allocate_temp(strides_nbytes);
  void* strides_data_ptr =
      temp_mem_res_strides.ok() ? temp_mem_res_strides.get() : nullptr;
  ET_CHECK_MSG(
      strides_data_ptr != nullptr, "Failed to malloc for strides bytes");
  memcpy(strides_data_ptr, strides.data(), strides_nbytes);

  // Placement-new the TensorImpl itself into temp memory.
  Result<void*> temp_mem_res_tensor = ctx.allocate_temp(sizeof(TensorImpl));
  auto tensor_impl = static_cast<TensorImpl*>(
      temp_mem_res_tensor.ok() ? temp_mem_res_tensor.get() : nullptr);
  ET_CHECK_MSG(tensor_impl != nullptr, "Failed to malloc for data TensorImpl");

  new (tensor_impl) TensorImpl(
      dtype,
      dim,
      reinterpret_cast<Tensor::SizesType*>(size_data_ptr),
      nullptr,
      reinterpret_cast<Tensor::DimOrderType*>(dim_order_data_ptr),
      reinterpret_cast<Tensor::StridesType*>(strides_data_ptr));

  // Allocate the data buffer last, once nbytes() is computable from the impl.
  Result<void*> temp_mem_res_data = ctx.allocate_temp(tensor_impl->nbytes());
  void* data_ptr = temp_mem_res_data.ok() ? temp_mem_res_data.get() : nullptr;
  ET_CHECK_MSG(data_ptr != nullptr, "Failed to malloc for data buffer");
  tensor_impl->set_data(data_ptr);

  return Tensor{tensor_impl};
}
72+
73+
} // namespace executor
74+
} // namespace torch
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
#pragma once
4+
5+
#include <executorch/runtime/kernel/kernel_includes.h>
6+
7+
namespace torch {
8+
namespace executor {
9+
10+
/// Allocates a Tensor (metadata, TensorImpl, and data buffer) entirely from
/// the kernel runtime context's temp allocator.
///
/// NOTE(review): the returned tensor presumably lives only as long as the
/// temp allocator's backing memory — confirm lifetime with callers.
///
/// @param ctx Kernel runtime context providing the temp allocator.
/// @param sizes Size of each dimension; its length defines the tensor rank.
/// @param dim_order Dimension order; expected to match sizes in length.
/// @param strides Strides; expected to match sizes in length.
/// @param dtype Scalar type of the tensor elements.
/// @return A Tensor backed by temp-allocated memory.
Tensor allocate_tensor(
    KernelRuntimeContext& ctx,
    const ArrayRef<Tensor::SizesType>& sizes,
    const ArrayRef<Tensor::DimOrderType>& dim_order,
    const ArrayRef<Tensor::StridesType>& strides,
    const ScalarType& dtype);
16+
17+
} // namespace executor
18+
} // namespace torch

kernels/portable/cpu/util/targets.bzl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,16 @@ def define_common_targets():
237237
visibility = ["//executorch/kernels/portable/cpu/..."],
238238
)
239239

240+
runtime.cxx_library(
241+
name = "allocate_tensor_util",
242+
srcs = ["allocate_tensor_util.cpp"],
243+
exported_headers = ["allocate_tensor_util.cpp"],
244+
deps = [
245+
"//executorch/runtime/kernel:kernel_includes",
246+
],
247+
visibility = ["//executorch/kernels/portable/cpu/..."],
248+
)
249+
240250
# Utility functions that can be used by operators that perform reduction
241251
for aten_mode in [True, False]:
242252
suffix = "_aten" if aten_mode else ""
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <gtest/gtest.h>
10+
11+
#include <executorch/kernels/portable/cpu/util/allocate_tensor_util.h>
12+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
13+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
14+
#include <executorch/runtime/kernel/kernel_includes.h>
15+
#include <executorch/runtime/platform/runtime.h>
16+
#include <executorch/test/utils/DeathTest.h>
17+
using ScalarType = exec_aten::ScalarType;
18+
19+
// Test fixture that initializes the ExecuTorch Platform Abstraction Layer
// before each test.
// NOTE(review): SetUp() runs only for tests declared with
// TEST_F(AllocateTest, ...). Plain TEST(...) cases do not instantiate this
// fixture, so this initialization does not apply to them — verify the test
// macros below match the intent.
class AllocateTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // Since these tests cause ET_LOG to be called, the PAL must be initialized
    // first.
    torch::executor::runtime_init();
  }
};
27+
28+
// Verifies allocate_tensor succeeds when the temp allocator is large enough
// for the metadata, the TensorImpl, and the 1x2x3 float data buffer.
// Fix: use TEST_F so the AllocateTest fixture's SetUp() actually runs and
// initializes the PAL (TEST(...) bypassed the fixture entirely).
TEST_F(AllocateTest, AllocateTensor) {
  // 2 KiB is ample for the metadata plus 6 floats of data.
  uint8_t* temp_allocator_ptr = (uint8_t*)malloc(2048);
  executorch::runtime::MemoryAllocator temp_allocator(2048, temp_allocator_ptr);
  executorch::runtime::KernelRuntimeContext ctx(nullptr, &temp_allocator);

  executorch::aten::SizesType sizes[3] = {1, 2, 3};
  executorch::aten::DimOrderType dim_order[3] = {0, 1, 2};
  executorch::aten::StridesType strides[3] = {3, 3, 1};

  torch::executor::ArrayRef<exec_aten::SizesType> sizes_ref(sizes, 3);
  torch::executor::ArrayRef<exec_aten::StridesType> strides_ref(strides, 3);
  torch::executor::ArrayRef<exec_aten::DimOrderType> dim_orders_ref(
      dim_order, 3);

  // Fix: pass the ArrayRefs that were constructed above (previously they
  // were unused and the raw arrays were converted implicitly instead).
  torch::executor::allocate_tensor(
      ctx, sizes_ref, dim_orders_ref, strides_ref, ScalarType::Float);

  free(temp_allocator_ptr);
}
47+
48+
// Verifies allocate_tensor dies with a "Failed to malloc" message when the
// temp allocator is too small to satisfy every allocation.
// Fix: use TEST_F so the fixture's SetUp() initializes the PAL, and drop the
// redundant manual runtime_init() call that compensated for TEST(...) not
// running the fixture.
TEST_F(AllocateTest, FailAllocateTensor) {
  // 16 bytes cannot satisfy all of the temp allocations.
  uint8_t* temp_allocator_ptr = (uint8_t*)malloc(16);
  executorch::runtime::MemoryAllocator temp_allocator(16, temp_allocator_ptr);
  executorch::runtime::KernelRuntimeContext ctx(nullptr, &temp_allocator);

  executorch::aten::SizesType sizes[3] = {1, 2, 3};
  executorch::aten::DimOrderType dim_order[3] = {0, 1, 2};
  executorch::aten::StridesType strides[3] = {3, 3, 1};

  torch::executor::ArrayRef<exec_aten::SizesType> sizes_ref(sizes, 3);
  torch::executor::ArrayRef<exec_aten::StridesType> strides_ref(strides, 3);
  torch::executor::ArrayRef<exec_aten::DimOrderType> dim_orders_ref(
      dim_order, 3);

  // Fix: pass the ArrayRefs that were constructed above (previously unused).
  ET_EXPECT_DEATH(
      torch::executor::allocate_tensor(
          ctx, sizes_ref, dim_orders_ref, strides_ref, ScalarType::Float),
      "Failed to malloc");

  free(temp_allocator_ptr);
}

kernels/portable/cpu/util/test/targets.bzl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,13 @@ def define_common_targets():
2121
"//executorch/kernels/portable/cpu/util:reduce_util",
2222
],
2323
)
24+
25+
    runtime.cxx_test(
        name = "allocate_tensor_test",
        srcs = ["allocate_tensor_test.cpp"],
        # NOTE(review): the test source also includes
        # executorch/runtime/platform/runtime.h and
        # executorch/test/utils/DeathTest.h — confirm those are provided
        # transitively by the deps below, or add explicit deps for them.
        deps = [
            "//executorch/runtime/core/exec_aten:lib",
            "//executorch/kernels/portable/cpu/util:allocate_tensor_util",
            "//executorch/runtime/kernel:kernel_includes",
        ],
    )

0 commit comments

Comments
 (0)