Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions backends/xnnpack/test/runtime/test_xnn_data_separation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,21 @@ TEST_F(DataSeparationTest, TestE2E) {
"forward", &mmm.get(), nullptr, linear_data_map_.get());
ASSERT_EQ(method.error(), Error::Ok);

// Set a dummy input.
int32_t sizes[1] = {3};
uint8_t dim_order[1] = {0};
int32_t strides[1] = {1};
executorch::aten::TensorImpl impl(
executorch::aten::ScalarType::Float,
1,
sizes,
nullptr,
dim_order,
strides);
auto input_err = method->set_input(
executorch::runtime::EValue(executorch::aten::Tensor(&impl)), 0);
ASSERT_EQ(input_err, Error::Ok);

// Can execute the method.
Error err = method->execute();
ASSERT_EQ(err, Error::Ok);
Expand Down
39 changes: 7 additions & 32 deletions extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ runtime::Error Module::load_method(
method_holder.memory_manager.get(),
event_tracer ? event_tracer : this->event_tracer(),
data_map_.get()));
method_holder.inputs.resize(method_holder.method->inputs_size());
methods_.emplace(method_name, std::move(method_holder));
}
return runtime::Error::Ok;
Expand All @@ -233,28 +232,10 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
const std::vector<runtime::EValue>& input_values) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
auto& method = methods_.at(method_name).method;
auto& inputs = methods_.at(method_name).inputs;

ET_CHECK_OR_RETURN_ERROR(
input_values.size() <= inputs.size(),
InvalidArgument,
"input size: %zu does not match method input size: %zu",
input_values.size(),
inputs.size());
for (size_t i = 0; i < input_values.size(); ++i) {
if (!input_values[i].isNone()) {
inputs[i] = input_values[i];
}
for (auto index = 0; index < input_values.size(); ++index) {
ET_CHECK_OK_OR_RETURN_ERROR(method->set_input(input_values[index], index));
}
for (size_t i = 0; i < inputs.size(); ++i) {
ET_CHECK_OR_RETURN_ERROR(
!inputs[i].isNone(), InvalidArgument, "input %zu is none", i);
}
ET_CHECK_OK_OR_RETURN_ERROR(
method->set_inputs(executorch::aten::ArrayRef<runtime::EValue>(
inputs.data(), inputs.size())));
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());

const auto outputs_size = method->outputs_size();
std::vector<runtime::EValue> outputs(outputs_size);
ET_CHECK_OK_OR_RETURN_ERROR(
Expand All @@ -268,23 +249,17 @@ runtime::Error Module::set_input(
const runtime::EValue& input_value,
size_t input_index) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
methods_.at(method_name).inputs.at(input_index) = input_value;
return runtime::Error::Ok;
auto& method = methods_.at(method_name).method;
return method->set_input(input_value, input_index);
}

runtime::Error Module::set_inputs(
const std::string& method_name,
const std::vector<runtime::EValue>& input_values) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
auto& inputs = methods_.at(method_name).inputs;
ET_CHECK_OR_RETURN_ERROR(
inputs.size() == input_values.size(),
InvalidArgument,
"input size: %zu does not match method input size: %zu",
input_values.size(),
inputs.size());
inputs = input_values;
return runtime::Error::Ok;
auto& method = methods_.at(method_name).method;
return method->set_inputs(executorch::aten::ArrayRef<runtime::EValue>(
input_values.data(), input_values.size()));
}

runtime::Error Module::set_output(
Expand Down
1 change: 0 additions & 1 deletion extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,6 @@ class Module {
std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
std::unique_ptr<runtime::MemoryManager> memory_manager;
std::unique_ptr<Method> method;
std::vector<runtime::EValue> inputs;
};

std::string file_path_;
Expand Down
2 changes: 1 addition & 1 deletion extension/module/test/module_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ TEST_F(ModuleTest, TestForward) {
EXPECT_TENSOR_CLOSE(result->at(0).toTensor(), *expected.get());

auto tensor2 = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 5.f});
const auto result2 = module->forward({tensor2, tensor2});
const auto result2 = module->forward({tensor2, tensor2, 1.0});
EXPECT_EQ(result2.error(), Error::Ok);

const auto expected2 = make_tensor_ptr({2, 2}, {4.f, 6.f, 8.f, 10.f});
Expand Down
2 changes: 2 additions & 0 deletions extension/runner_util/test/inputs_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class InputsTest : public ::testing::Test {
TEST_F(InputsTest, Smoke) {
Result<BufferCleanup> input_buffers = prepare_input_tensors(*method_);
ASSERT_EQ(input_buffers.error(), Error::Ok);
auto input_err = method_->set_input(executorch::runtime::EValue(1.0), 2);
ASSERT_EQ(input_err, Error::Ok);

// We can't look at the input tensors, but we can check that the outputs make
// sense after executing the method.
Expand Down
73 changes: 25 additions & 48 deletions runtime/core/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,42 +205,37 @@ using ::executorch::runtime::error_code_t;
* @param[in] ... Optional format string for the log error message and its
* arguments.
*/
#define ET_CHECK_OK_OR_RETURN_ERROR(error__, ...) \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR(error__, ##__VA_ARGS__)

// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead.
#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...) \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT( \
__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) \
(__VA_ARGS__)
#define ET_CHECK_OK_OR_RETURN_ERROR(...) \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR(__VA_ARGS__)

/**
* Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead.
* This macro selects the correct version of
* ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR based on the number of arguments passed.
* It uses a trick with the preprocessor to count the number of arguments and
* then selects the appropriate macro.
*
* The macro expansion uses __VA_ARGS__ to accept any number of arguments and
* then appends them to ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_, followed by the
* count of arguments. The count is determined by the macro
* ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT which takes the arguments and
* passes them along with a sequence of numbers (2, 1). The preprocessor then
* matches this sequence to the correct number of arguments provided.
*
* If two arguments are passed, ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 is
* selected, suitable for cases where an error code and a custom message are
* provided. If only one argument is passed,
* ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1 is selected, which is used for cases
* with just an error code.
*
* Usage:
* ET_CHECK_OK_OR_RETURN_ERROR(error_code); // Calls v1
* ET_CHECK_OK_OR_RETURN_ERROR(error_code, "Error message", ...); // Calls v2
* It uses a helper that reliably picks the 1-arg or 2+-arg form on
* MSVC/Clang/GCC.
*/
#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT( \
_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_##N
#define ET_INTERNAL_EXPAND(x) x
#define ET_INTERNAL_GET_MACRO( \
_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, NAME, ...) \
NAME

// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead.
// Picks _2 for 2..10 args, _1 for exactly 1 arg.
#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...) \
ET_INTERNAL_EXPAND(ET_INTERNAL_GET_MACRO( \
__VA_ARGS__, \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 10 */ \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 9 */ \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 8 */ \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 7 */ \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 6 */ \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 5 */ \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 4 */ \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 3 */ \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2, /* 2 */ \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1 /* 1 */ \
)(__VA_ARGS__))

// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead.
#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1(error__) \
Expand All @@ -260,21 +255,3 @@ using ::executorch::runtime::error_code_t;
return et_error__; \
} \
} while (0)

// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead.
#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_3 \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_4 \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_5 \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_6 \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_7 \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_8 \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_9 \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_10 \
ET_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
97 changes: 50 additions & 47 deletions runtime/executor/method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,22 @@ Error Method::parse_values(const NamedDataMap* external_data_map) {
auto flatbuffer_values = serialization_plan_->values();
ET_CHECK_OR_RETURN_ERROR(
flatbuffer_values != nullptr, InvalidProgram, "Missing values");
size_t n_value = flatbuffer_values->size();
const size_t n_value = flatbuffer_values->size();
values_ = memory_manager_->method_allocator()->allocateList<EValue>(n_value);
if (values_ == nullptr) {
return Error::MemoryAllocationFailed;
}
const size_t n_input = inputs_size();
if (n_input > 0) {
input_set_ =
memory_manager_->method_allocator()->allocateList<bool>(n_input);
if (input_set_ == nullptr) {
return Error::MemoryAllocationFailed;
}
for (size_t i = 0; i < n_input; ++i) {
input_set_[i] = false;
}
}

// Count the number of tensors marked as EXTERNAL for this method. The actual
// number of external constants may be smaller, eg. if multiple tensors point
Expand Down Expand Up @@ -1076,26 +1087,22 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) {
executorch::runtime::toString(t_src.scalar_type()));
// Reset the shape for the Method's input as the size of forwarded input
// tensor for shape dynamism. Also is a safety check if need memcpy.
Error err = resize_tensor(t_dst, t_src.sizes());
ET_CHECK_OR_RETURN_ERROR(
err == Error::Ok,
InvalidArgument,
"Error setting input %" ET_PRIsize_t ": 0x%" PRIx32,
input_idx,
static_cast<uint32_t>(err));
Error error;
ET_CHECK_OK_OR_RETURN_ERROR(
resize_tensor(t_dst, t_src.sizes()),
"Error resizing tensor at input %" ET_PRIsize_t,
input_idx);
auto tensor_meta = this->method_meta().input_tensor_meta(input_idx);
if (tensor_meta->is_memory_planned()) {
error = internal::copy_tensor_data(t_dst, t_src);
ET_CHECK_OK_OR_RETURN_ERROR(
internal::copy_tensor_data(t_dst, t_src),
"Error copying tensor data at input %" ET_PRIsize_t,
input_idx);
} else {
error = internal::share_tensor_data(t_dst, t_src);
ET_CHECK_OK_OR_RETURN_ERROR(
internal::share_tensor_data(t_dst, t_src),
"Error sharing tensor data at input %" ET_PRIsize_t,
input_idx);
}
ET_CHECK_OR_RETURN_ERROR(
error == Error::Ok,
InvalidArgument,
"Error setting data_ptr %" ET_PRIsize_t ": 0x%" PRIx32,
input_idx,
static_cast<uint32_t>(error));
// Prims have to be the same as what was traced
} else if (e.isInt()) {
ET_CHECK_OR_RETURN_ERROR(
Expand Down Expand Up @@ -1163,35 +1170,16 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) {

return Error::InvalidArgument;
}
input_set_[input_idx] = true;

return Error::Ok;
}

ET_NODISCARD Error
Method::set_inputs(const executorch::aten::ArrayRef<EValue>& input_evalues) {
ET_CHECK_OR_RETURN_ERROR(
initialized(),
InvalidState,
"Inputs can not be set until method has been initialized.");

ET_CHECK_OR_RETURN_ERROR(
step_state_.instr_idx == 0 && step_state_.chain_idx == 0,
InvalidState,
"Inputs can not be set mid execution.");

size_t input_size = inputs_size();
ET_CHECK_OR_RETURN_ERROR(
input_size == input_evalues.size(),
InvalidArgument,
"The length of given input array (%" ET_PRIsize_t
") must be same as the number of inputs in method (%" ET_PRIsize_t ").",
input_evalues.size(),
input_size);

for (size_t i = 0; i < input_size; i++) {
Error status = set_input(input_evalues[i], i);
if (status != Error::Ok) {
return status;
}
const size_t n_input = inputs_size();
for (size_t i = 0; i < n_input; ++i) {
ET_CHECK_OK_OR_RETURN_ERROR(set_input(input_evalues[i], i));
}
return Error::Ok;
}
Expand Down Expand Up @@ -1284,20 +1272,21 @@ ET_NODISCARD Error Method::get_inputs(EValue* input_evalues, size_t length) {
initialized(),
InvalidState,
"Inputs can not be retrieved until method has been initialized.");

const size_t n_input = inputs_size();
ET_CHECK_OR_RETURN_ERROR(
length >= inputs_size(),
length >= n_input,
InvalidArgument,
"The given array is not large enough to hold all inputs.");

for (size_t i = 0; i < inputs_size(); i++) {
for (size_t i = 0; i < n_input; ++i) {
input_evalues[i] = values_[get_input_index(i)];
// Accessing inputs this way is deprecated.
// We assume the users to be responsible to set the inputs they get.
input_set_[i] = true;
}

for (size_t i = inputs_size(); i < length; i++) {
for (size_t i = n_input; i < length; ++i) {
input_evalues[i] = EValue();
}

return Error::Ok;
}

Expand Down Expand Up @@ -1545,6 +1534,14 @@ Error Method::execute() {
initialized(),
NotSupported,
"Cannot execute until method has been initialized.");
const size_t n_input = inputs_size();
for (size_t i = 0; i < n_input; ++i) {
ET_CHECK_OR_RETURN_ERROR(
input_set_[i],
InvalidArgument,
"Input %" ET_PRIsize_t " has not been set.",
i);
}
ET_LOG(Debug, "Executing method: %s.", method_meta().name());

// Chains are executed sequentially today, but future async designs may
Expand Down Expand Up @@ -1622,10 +1619,16 @@ size_t Method::get_input_index(size_t i) const {
}

const EValue& Method::get_input(size_t i) const {
// Accessing inputs this way is deprecated.
// We assume the users to be responsible to set the inputs they get.
input_set_[i] = true;
return get_value(get_input_index(i));
}

EValue& Method::mutable_input(size_t i) {
// Accessing inputs this way is deprecated.
// We assume the users to be responsible to set the inputs they get.
input_set_[i] = true;
return mutable_value(get_input_index(i));
}

Expand Down
Loading
Loading