44 changes: 36 additions & 8 deletions runtime/executor/method_meta.cpp
@@ -52,7 +52,7 @@ Result<Tag> get_tag(
  }
}

-size_t calculate_nbytes(
+Result<size_t> calculate_nbytes(
    Span<const int32_t> sizes,
    executorch::aten::ScalarType scalar_type) {
  size_t n = 1;
@@ -61,7 +61,13 @@ size_t calculate_nbytes(
    prev_n = n;
    n *= sizes[i];
    // Check for overflow
-    ET_CHECK(sizes[i] == 0 || n / sizes[i] == prev_n);
+    ET_CHECK_OR_RETURN_ERROR(
+        sizes[i] == 0 || n / sizes[i] == prev_n,
+        InvalidArgument,
+        "Invalid size[%zu]: %d. Potentially overflowed, expect to be 0 or prev_n: %zu",
+        i,
+        sizes[i],
+        prev_n);
  }

  size_t elem_size = executorch::runtime::elementSize(scalar_type);
@@ -70,25 +76,47 @@ size_t calculate_nbytes(
  n = n * elem_size;

  // Check for overflow
-  ET_CHECK(elem_size == 0 || n / elem_size == prev_n);
+  ET_CHECK_OR_RETURN_ERROR(
+      elem_size == 0 || n / elem_size == prev_n,
+      InvalidArgument,
+      "Invalid elem_size: %zu. Potentially overflowed, expect to be 0 or prev_n: %zu",
+      elem_size,
+      prev_n);

  return n;
}

} // namespace

+/*static*/ Result<TensorInfo> TensorInfo::create(
+    Span<const int32_t> sizes,
+    Span<const uint8_t> dim_order,
+    executorch::aten::ScalarType scalar_type,
+    const bool is_memory_planned,
+    std::string_view name) {
+  auto nbytes = calculate_nbytes(sizes, scalar_type);
+  ET_CHECK_OR_RETURN_ERROR(
+      nbytes.ok(),
+      InvalidArgument,
+      "Failed to calculate nbytes for TensorInfo");
+
+  return TensorInfo(
+      sizes, dim_order, scalar_type, is_memory_planned, name, nbytes.get());
+}
+
TensorInfo::TensorInfo(
    Span<const int32_t> sizes,
    Span<const uint8_t> dim_order,
    executorch::aten::ScalarType scalar_type,
    const bool is_memory_planned,
-    std::string_view name)
+    std::string_view name,
+    size_t nbytes)
    : sizes_(sizes),
      dim_order_(dim_order),
      name_(name),
      scalar_type_(scalar_type),
      is_memory_planned_(is_memory_planned),
-      nbytes_(calculate_nbytes(sizes_, scalar_type_)) {}
+      nbytes_(nbytes) {}

Span<const int32_t> TensorInfo::sizes() const {
  return sizes_;
@@ -160,7 +188,7 @@ Result<TensorInfo> MethodMeta::input_tensor_meta(size_t index) const {
  auto input_index = s_plan_->inputs()->Get(index);
  // input_index was already validated by input_tag().
  auto tensor_value = s_plan_->values()->Get(input_index)->val_as_Tensor();
-  return TensorInfo(
+  return TensorInfo::create(
      Span<const int32_t>(
          tensor_value->sizes()->data(), tensor_value->sizes()->size()),
      Span<const uint8_t>(
@@ -212,7 +240,7 @@ Result<TensorInfo> MethodMeta::output_tensor_meta(size_t index) const {
  // output_index was already validated by output_tag().
  auto tensor_value = s_plan_->values()->Get(output_index)->val_as_Tensor();

-  return TensorInfo(
+  return TensorInfo::create(
      Span<const int32_t>(
          tensor_value->sizes()->data(), tensor_value->sizes()->size()),
      Span<const uint8_t>(
@@ -255,7 +283,7 @@ Result<TensorInfo> MethodMeta::attribute_tensor_meta(size_t index) const {
      auto t_name =
          tensor_value->extra_tensor_info()->fully_qualified_name();
      // Count constant returns as memory planned
-      return TensorInfo(
+      return TensorInfo::create(
          Span<const int32_t>(
              tensor_value->sizes()->data(), tensor_value->sizes()->size()),
          Span<const uint8_t>(
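The check n / sizes[i] == prev_n works because size_t arithmetic wraps modulo 2^N: dividing the product by the multiplier recovers the previous value only when no wrap occurred. The following standalone sketch shows the same idiom in isolation; checked_nbytes and the use of std::optional in place of the ExecuTorch Result type are illustrative assumptions, not part of this change.

// Standalone sketch of the overflow-checked byte-count calculation used by
// calculate_nbytes above. std::optional stands in for the ExecuTorch Result
// type; checked_nbytes is an illustrative name, not part of the runtime API.
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

std::optional<size_t> checked_nbytes(
    const std::vector<int32_t>& sizes,
    size_t elem_size) {
  size_t n = 1;
  for (int32_t s : sizes) {
    if (s < 0) {
      return std::nullopt;  // negative dimensions are invalid
    }
    const size_t prev_n = n;
    n *= static_cast<size_t>(s);
    // The division recovers prev_n only when the multiply did not wrap
    // (s == 0 is excluded to avoid dividing by zero).
    if (s != 0 && n / static_cast<size_t>(s) != prev_n) {
      return std::nullopt;  // overflow
    }
  }
  const size_t prev_n = n;
  n *= elem_size;
  if (elem_size != 0 && n / elem_size != prev_n) {
    return std::nullopt;  // overflow when scaling by the element size
  }
  return n;
}

int main() {
  // 2 x 3 tensor of 4-byte elements -> 24 bytes.
  const auto ok = checked_nbytes({2, 3}, 4);
  // Three INT32_MAX dimensions overflow size_t, so the check fires.
  const auto bad = checked_nbytes({INT32_MAX, INT32_MAX, INT32_MAX}, 4);
  return (ok.has_value() && ok.value() == 24 && !bad.has_value()) ? 0 : 1;
}

With calculate_nbytes now returning Result<size_t>, the MethodMeta accessors above can return TensorInfo::create(...) directly: they already return Result<TensorInfo>, so a failed size calculation propagates to the caller instead of aborting inside the runtime.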
21 changes: 20 additions & 1 deletion runtime/executor/method_meta.h
@@ -77,13 +77,32 @@ class TensorInfo final {
  friend class MethodMeta;
  friend class testing::TensorInfoTestFriend;

-  TensorInfo(
+  /**
+   * Create a TensorInfo instance.
+   *
+   * @param[in] sizes The sizes of the tensor.
+   * @param[in] dim_order The dim order of the tensor.
+   * @param[in] scalar_type The scalar type of the tensor.
+   * @param[in] is_memory_planned Whether the tensor's memory was planned.
+   * @param[in] name The fully qualified name of the tensor.
+   * @returns A Result containing the TensorInfo on success, or an error on
+   * failure.
+   */
+  static Result<TensorInfo> create(
      Span<const int32_t> sizes,
      Span<const uint8_t> dim_order,
      executorch::aten::ScalarType scalar_type,
      const bool is_memory_planned,
      std::string_view name);
+
+  TensorInfo(
+      Span<const int32_t> sizes,
+      Span<const uint8_t> dim_order,
+      executorch::aten::ScalarType scalar_type,
+      const bool is_memory_planned,
+      std::string_view name,
+      size_t nbytes);

  /**
   * The sizes of the tensor.
   *
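A constructor has no return value through which to report failure, and the ExecuTorch runtime is built without exceptions, so this header change pairs a static create() factory (which does the fallible nbytes computation and returns a Result) with a constructor that only stores precomputed values. The sketch below shows the shape of that two-phase construction pattern in isolation; TensorInfoSketch, its members, and the use of std::optional instead of Result are illustrative assumptions, not the real API.

// Minimal two-phase construction sketch: a static factory performs the work
// that can fail and reports it through its return value, while the private
// constructor only stores already-validated values.
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

class TensorInfoSketch {
 public:
  static std::optional<TensorInfoSketch> create(
      std::vector<int32_t> sizes,
      size_t elem_size) {
    size_t n = elem_size;
    for (const int32_t s : sizes) {
      if (s < 0) {
        return std::nullopt;  // invalid shape: the constructor never runs
      }
      n *= static_cast<size_t>(s);  // overflow check elided in this sketch
    }
    return TensorInfoSketch(std::move(sizes), n);
  }

  size_t nbytes() const {
    return nbytes_;
  }

 private:
  // Private: callers must go through create(), so nbytes_ is always valid.
  TensorInfoSketch(std::vector<int32_t> sizes, size_t nbytes)
      : sizes_(std::move(sizes)), nbytes_(nbytes) {}

  std::vector<int32_t> sizes_;
  size_t nbytes_;
};

int main() {
  const auto info = TensorInfoSketch::create({2, 3, 4}, /*elem_size=*/4);
  return (info.has_value() && info->nbytes() == 96) ? 0 : 1;
}

Keeping the constructor trivial also means a TensorInfo can never exist with an inconsistent nbytes_ value.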
13 changes: 7 additions & 6 deletions runtime/executor/test/method_meta_test.cpp
@@ -39,12 +39,13 @@ class TensorInfoTestFriend final {
      executorch::aten::ScalarType scalar_type,
      const bool is_memory_planned,
      executorch::aten::string_view name) {
-    return TensorInfo(
-        Span<const int32_t>(sizes.data(), sizes.size()),
-        Span<const uint8_t>(dim_order.data(), dim_order.size()),
-        scalar_type,
-        is_memory_planned,
-        name);
+    return TensorInfo::create(
+               Span<const int32_t>(sizes.data(), sizes.size()),
+               Span<const uint8_t>(dim_order.data(), dim_order.size()),
+               scalar_type,
+               is_memory_planned,
+               name)
+        .get();
  }
};
} // namespace testing
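The test-only TensorInfoTestFriend keeps its existing signature by unwrapping the Result on the spot: .get() is expected to fail hard via the runtime's internal check when the Result holds an error, so tests that build valid TensorInfo objects behave as before, while a test exercising the failure path would inspect ok() or error() on the Result instead.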