Skip to content

Commit 094d00e

Browse files
JacobSzwejbkafacebook-github-bot
authored and committed
set_output_data_ptr api (#223)
Summary: Pull Request resolved: #223 People really shouldn't be using get_output and mutating the structure; this provides a way to set the output data ptr in a more controlled manner Reviewed By: iseeyuan Differential Revision: D49029435 fbshipit-source-id: 44f527d99a0d2c50bbe5a022757adcbd4f7ae20f
1 parent 64d451f commit 094d00e

File tree

10 files changed

+256
-28
lines changed

10 files changed

+256
-28
lines changed

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,12 @@ __ET_NODISCARD Error copy_tensor_data(
910910
const exec_aten::Tensor& t_dst,
911911
const exec_aten::Tensor& t_src);
912912

913+
/**
914+
* Set the data_ptr of t to buffer.
915+
*/
916+
__ET_NODISCARD Error
917+
set_tensor_data(const exec_aten::Tensor& t, void* buffer, size_t buffer_size);
918+
913919
/**
914920
* Reset tensor's data_ptr, clear all the storage for at::Tensor.
915921
*/

runtime/core/exec_aten/util/tensor_util_aten.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,19 @@ Error copy_tensor_data(const at::Tensor& t_dst, const at::Tensor& t_src) {
133133
return Error::Ok;
134134
}
135135

136+
__ET_NODISCARD Error
137+
set_tensor_data(const at::Tensor& t, void* buffer, size_t buffer_size) {
138+
ET_CHECK_OR_RETURN_ERROR(
139+
buffer_size >= t.nbytes(),
140+
InvalidArgument,
141+
"buffer_size %zu is smaller than tensor nbytes %zu",
142+
buffer_size,
143+
t.nbytes());
144+
t.unsafeGetTensorImpl()->unsafe_storage().set_data_ptr(
145+
at::DataPtr(buffer, DeviceType::CPU));
146+
return Error::Ok;
147+
}
148+
136149
void reset_data_ptr(const at::Tensor& tensor) {
137150
auto impl = tensor.unsafeGetTensorImpl();
138151
impl->set_sizes_contiguous(0);

runtime/core/exec_aten/util/tensor_util_portable.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,20 @@ Error copy_tensor_data(
116116
return Error::Ok;
117117
}
118118

119+
__ET_NODISCARD Error set_tensor_data(
120+
const torch::executor::Tensor& t,
121+
void* buffer,
122+
size_t buffer_size) {
123+
ET_CHECK_OR_RETURN_ERROR(
124+
buffer_size >= t.nbytes(),
125+
InvalidArgument,
126+
"buffer_size %zu is smaller than tensor nbytes %zu",
127+
buffer_size,
128+
t.nbytes());
129+
t.unsafeGetTensorImpl()->set_data(buffer);
130+
return Error::Ok;
131+
}
132+
119133
void reset_data_ptr(const torch::executor::Tensor& tensor) {
120134
// Lean mode doesn't deallocate the tensor data_ptr in the allocator
121135
tensor.set_data(nullptr);

runtime/executor/method.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,17 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
632632
}
633633
}
634634

635+
pre_allocated_output_ = false;
636+
637+
// Get pre_allocation info for output tensors
638+
for (int i = 0; i < outputs_size(); i++) {
639+
if (get_output(i).isTensor()) {
640+
pre_allocated_output_ =
641+
get_output(i).toTensor().const_data_ptr() != nullptr;
642+
break;
643+
}
644+
}
645+
635646
ET_CHECK_OR_RETURN_ERROR(
636647
n_chains_ > 0,
637648
Internal,
@@ -799,6 +810,51 @@ Method::set_inputs(const exec_aten::ArrayRef<EValue>& input_evalues) {
799810
return Error::Ok;
800811
}
801812

813+
__ET_NODISCARD Error
814+
Method::set_output_data_ptr(void* buffer, size_t size, size_t output_idx) {
815+
// Check method state
816+
ET_CHECK_OR_RETURN_ERROR(
817+
initialized(),
818+
InvalidState,
819+
"Outputs can not be retrieved until method has been initialized.");
820+
821+
ET_CHECK_OR_RETURN_ERROR(
822+
!pre_allocated_output_,
823+
InvalidState,
824+
"Overriding output data pointer allocated by memory plan is not allowed.");
825+
826+
// Check the args
827+
ET_CHECK_OR_RETURN_ERROR(
828+
output_idx < outputs_size(),
829+
InvalidArgument,
830+
"output_idx: %zu num_outputs: %zu",
831+
output_idx,
832+
outputs_size());
833+
834+
auto& output = mutable_output(output_idx);
835+
ET_CHECK_OR_RETURN_ERROR(
836+
output.isTensor(),
837+
InvalidArgument,
838+
"output type: %zu is not tensor",
839+
(size_t)output.tag);
840+
841+
auto& t = output.toTensor();
842+
ET_CHECK_OR_RETURN_ERROR(
843+
output.isTensor(),
844+
InvalidArgument,
845+
"output type: %zu is not tensor",
846+
(size_t)output.tag);
847+
ET_CHECK_OR_RETURN_ERROR(
848+
t.nbytes() <= size,
849+
InvalidArgument,
850+
"buffer size: %zu is smaller than expected tensor size: %zu",
851+
size,
852+
t.nbytes());
853+
854+
// Set data
855+
return internal::set_tensor_data(t, buffer, size);
856+
}
857+
802858
__ET_NODISCARD Error
803859
Method::get_outputs(EValue* output_evalues, size_t length) {
804860
ET_CHECK_OR_RETURN_ERROR(

runtime/executor/method.h

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ class Method final {
6363
n_chains_(0),
6464
chains_(nullptr),
6565
init_state_(InitializationState::Uninitialized),
66-
pre_allocated_input_(false) {}
66+
pre_allocated_input_(false),
67+
pre_allocated_output_(false) {}
6768

6869
/**
6970
* Move ctor. Takes ownership of resources previously owned by `rhs`,
@@ -82,7 +83,8 @@ class Method final {
8283
n_chains_(rhs.n_chains_),
8384
chains_(rhs.chains_),
8485
init_state_(rhs.init_state_),
85-
pre_allocated_input_(rhs.pre_allocated_input_) {
86+
pre_allocated_input_(rhs.pre_allocated_input_),
87+
pre_allocated_output_(rhs.pre_allocated_output_) {
8688
// Required: clear out fields that the dtor looks at, so that we don't free
8789
// anything twice.
8890
rhs.n_value_ = 0;
@@ -97,10 +99,11 @@ class Method final {
9799
rhs.program_ = nullptr;
98100
rhs.memory_manager_ = nullptr;
99101
rhs.serialization_plan_ = nullptr;
102+
rhs.event_tracer_ = nullptr;
100103
rhs.n_chains_ = 0;
101104
rhs.chains_ = nullptr;
102105
rhs.pre_allocated_input_ = false;
103-
rhs.event_tracer_ = nullptr;
106+
rhs.pre_allocated_output_ = false;
104107
}
105108

106109
/**
@@ -144,6 +147,28 @@ class Method final {
144147
__ET_NODISCARD Error
145148
set_inputs(const exec_aten::ArrayRef<EValue>& input_evalues);
146149

150+
/**
151+
* Sets the data buffer of the specified method output to the provided value.
152+
*
153+
* NOTE: Based on the memory plan of the method, the output tensors may not
154+
* have buffer space pre-allocated for them, in this case the executor will
155+
* point those tensors to the buffer provided here, so the user should take
156+
* care that the life span of this memory outlasts the executor forward.
157+
*
158+
* @param[in] buffer The block of memory to point the specified tensor at.
159+
*
160+
* @param[in] size the length of buffer in bytes, must be >= the nbytes of the
161+
* specified tensor.
162+
*
163+
* @param[in] output_idx The index of the output to set the data_ptr for. Must
164+
* correspond to a tensor, and that tensor must not have had a buffer
165+
* allocated by the memory plan.
166+
*
167+
* @returns Error::Ok on success, non-Ok on failure.
168+
*/
169+
__ET_NODISCARD Error
170+
set_output_data_ptr(void* buffer, size_t size, size_t output_idx);
171+
147172
/**
148173
* Copies the method's outputs into the provided array.
149174
*
@@ -263,6 +288,7 @@ class Method final {
263288

264289
InitializationState init_state_;
265290
bool pre_allocated_input_;
291+
bool pre_allocated_output_;
266292

267293
/**
268294
* Parses the elements of the values_ array. On error, n_value_ will be set to

runtime/executor/test/method_test.cpp

Lines changed: 88 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -33,46 +33,41 @@ constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U;
3333

3434
class MethodTest : public ::testing::Test {
3535
protected:
36-
void SetUp() override {
37-
// Create a loader for the serialized ModuleAdd program.
38-
const char* path = std::getenv("ET_MODULE_ADD_PATH");
36+
void load_program(const char* path, const char* module_name) {
37+
// Create a loader for the serialized program.
3938
Result<FileDataLoader> loader = FileDataLoader::From(path);
4039
ASSERT_EQ(loader.error(), Error::Ok);
41-
add_loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
40+
loaders_.insert(
41+
{module_name,
42+
std::make_unique<FileDataLoader>(std::move(loader.get()))});
4243

4344
// Use it to load the program.
4445
Result<Program> program = Program::Load(
45-
add_loader_.get(), Program::Verification::InternalConsistency);
46+
loaders_[module_name].get(),
47+
Program::Verification::InternalConsistency);
4648
ASSERT_EQ(program.error(), Error::Ok);
47-
add_program_ = std::make_unique<Program>(std::move(program.get()));
48-
49-
// Create a loader for the serialized ModuleIndex program.
50-
const char* index_path = std::getenv("ET_MODULE_INDEX_PATH");
51-
Result<FileDataLoader> index_loader = FileDataLoader::From(index_path);
52-
ASSERT_EQ(index_loader.error(), Error::Ok);
53-
index_loader_ =
54-
std::make_unique<FileDataLoader>(std::move(index_loader.get()));
49+
programs_.insert(
50+
{module_name, std::make_unique<Program>(std::move(program.get()))});
51+
}
5552

56-
// Use it to load the program.
57-
Result<Program> index_program = Program::Load(
58-
index_loader_.get(), Program::Verification::InternalConsistency);
59-
ASSERT_EQ(index_program.error(), Error::Ok);
60-
index_program_ = std::make_unique<Program>(std::move(index_program.get()));
53+
void SetUp() override {
54+
load_program(std::getenv("ET_MODULE_ADD_PATH"), "add");
55+
load_program(std::getenv("ET_MODULE_INDEX_PATH"), "index");
56+
load_program(
57+
std::getenv("ET_MODULE_DYNAMIC_CAT_UNALLOCATED_IO_PATH"), "cat");
6158
}
6259

6360
private:
6461
// Must outlive program_, but tests shouldn't need to touch it.
65-
std::unique_ptr<FileDataLoader> add_loader_;
66-
std::unique_ptr<FileDataLoader> index_loader_;
62+
std::unordered_map<std::string, std::unique_ptr<FileDataLoader>> loaders_;
6763

6864
protected:
69-
std::unique_ptr<Program> add_program_;
70-
std::unique_ptr<Program> index_program_;
65+
std::unordered_map<std::string, std::unique_ptr<Program>> programs_;
7166
};
7267

7368
TEST_F(MethodTest, MoveTest) {
7469
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
75-
Result<Method> method = add_program_->load_method("forward", &mmm.get());
70+
Result<Method> method = programs_["add"]->load_method("forward", &mmm.get());
7671
ASSERT_EQ(method.error(), Error::Ok);
7772

7873
// Can execute the method.
@@ -97,7 +92,7 @@ TEST_F(MethodTest, MoveTest) {
9792

9893
TEST_F(MethodTest, SetPrimInputTest) {
9994
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
100-
Result<Method> method = add_program_->load_method("forward", &mmm.get());
95+
Result<Method> method = programs_["add"]->load_method("forward", &mmm.get());
10196
ASSERT_EQ(method.error(), Error::Ok);
10297

10398
// Can execute the method.
@@ -121,6 +116,75 @@ TEST_F(MethodTest, SetPrimInputTest) {
121116
torch::executor::util::FreeInputs(inputs);
122117
}
123118

119+
TEST_F(MethodTest, AliasedIOTest) {
120+
// TODO(T163238401)
121+
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
122+
Result<Method> method = programs_["cat"]->load_method("forward", &mmm.get());
123+
ASSERT_EQ(method.error(), Error::Ok);
124+
125+
// Set up io. Input and Output should share the same memory.
126+
constexpr int buffer_size = 16;
127+
float buffer[buffer_size]; // Initial input is (2,4) we then cat a (1,4) to it
128+
// twice for a final shape of (4,4)
129+
for (int i = 0; i < buffer_size; ++i) {
130+
buffer[i] = 0.f;
131+
}
132+
int32_t sizes[2] = {2, 4};
133+
uint8_t dim_order[2] = {0, 1};
134+
int32_t strides[2] = {4, 1};
135+
torch::executor::TensorImpl impl(
136+
torch::executor::ScalarType::Float, 2, sizes, buffer, dim_order, strides);
137+
138+
auto input_err = method->set_input(EValue(torch::executor::Tensor(&impl)), 0);
139+
ASSERT_EQ(input_err, Error::Ok);
140+
141+
auto output_err = method->set_output_data_ptr(buffer, sizeof(buffer), 0);
142+
ASSERT_EQ(output_err, Error::Ok);
143+
ASSERT_EQ(method->get_output(0).toTensor().const_data_ptr(), buffer);
144+
145+
// Execute the method once. Cat a 1x4 to a 2x4.
146+
auto execute_error = method->execute();
147+
ASSERT_EQ(execute_error, Error::Ok);
148+
149+
auto output = method->get_output(0);
150+
ASSERT_TRUE(output.isTensor());
151+
EXPECT_EQ(output.toTensor().sizes()[0], 3);
152+
EXPECT_EQ(output.toTensor().sizes()[1], 4);
153+
// Original input should be 0.
154+
for (size_t i = 0; i < 2 * 4; i++) {
155+
EXPECT_FLOAT_EQ(output.toTensor().const_data_ptr<float>()[i], 0.f);
156+
}
157+
// Section that was cat on should be 1.
158+
for (size_t i = 0; i < 1 * 4; i++) {
159+
EXPECT_FLOAT_EQ(
160+
output.toTensor().const_data_ptr<float>()[(2 * 4) + i], 1.f);
161+
}
162+
163+
// Set the input again to update the size.
164+
sizes[0] = output.toTensor().sizes()[0];
165+
torch::executor::TensorImpl impl_2(
166+
torch::executor::ScalarType::Float, 2, sizes, buffer, dim_order, strides);
167+
input_err = method->set_input(EValue(torch::executor::Tensor(&impl_2)), 0);
168+
ASSERT_EQ(input_err, Error::Ok);
169+
170+
// Execute the method again. Cat a 1x4 to a 3x4.
171+
execute_error = method->execute();
172+
ASSERT_EQ(execute_error, Error::Ok);
173+
174+
output = method->get_output(0);
175+
EXPECT_EQ(output.toTensor().sizes()[0], 4);
176+
EXPECT_EQ(output.toTensor().sizes()[1], 4);
177+
// Original input should be 0.
178+
for (size_t i = 0; i < 2 * 4; i++) {
179+
EXPECT_FLOAT_EQ(output.toTensor().const_data_ptr<float>()[i], 0.f);
180+
}
181+
// Previous section and the new one that were cat on should be 1.
182+
for (size_t i = 0; i < 2 * 4; i++) {
183+
EXPECT_FLOAT_EQ(
184+
output.toTensor().const_data_ptr<float>()[(2 * 4) + i], 1.f);
185+
}
186+
}
187+
124188
// TODO(T161163608): Test is disabled due to a resize bug in tensor_index_out of
125189
// the portable op lib
126190

runtime/executor/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def define_common_targets(is_fbcode = False):
103103
# an fbcode target path because the authoring/export tools
104104
# intentionally don't work in xplat (since they're host-only tools).
105105
"ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
106+
"ET_MODULE_DYNAMIC_CAT_UNALLOCATED_IO_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleDynamicCatUnallocatedIO.pte])",
106107
"ET_MODULE_INDEX_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleIndex.pte])",
107108
"ET_MODULE_MULTI_ENTRY_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleMultipleEntry.pte])",
108109
}

test/end2end/exported_module.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,26 @@ def return_wrapper():
130130
for method in methods:
131131
method_name_to_args[method] = trace_inputs
132132

133+
method_name_to_constraints = None
134+
if hasattr(eager_module, "get_constraints"):
135+
assert capture_config is not None
136+
assert capture_config.enable_aot is True
137+
trace_constraints = eager_module.get_constraints()
138+
method_name_to_constraints = {}
139+
for method in methods:
140+
method_name_to_constraints[method] = trace_constraints
141+
142+
memory_planning_pass = MemoryPlanningPass("greedy")
143+
if hasattr(eager_module, "get_memory_planning_pass"):
144+
memory_planning_pass = eager_module.get_memory_planning_pass()
145+
133146
# Capture an executorch program.
134147
executorch_program = (
135148
exir.capture_multiple(
136149
eager_module,
137150
method_name_to_args,
138151
capture_config,
152+
constraints=method_name_to_constraints,
139153
)
140154
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
141155
.to_executorch(
@@ -150,7 +164,7 @@ def return_wrapper():
150164
to_scratch_op_pass,
151165
],
152166
dynamic_memory_planning_mode=dynamic_memory_planning_mode,
153-
memory_planning_pass=MemoryPlanningPass("greedy"),
167+
memory_planning_pass=memory_planning_pass,
154168
to_out_var_pass=ToOutVarPass(ignore_to_out_var_failure),
155169
)
156170
)

0 commit comments

Comments
 (0)