Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 75 additions & 99 deletions extension/flat_tensor/flat_tensor_data_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
for (int i = 0; i < tensors->size(); i++) {
if (std::strcmp(tensors->Get(i)->fully_qualified_name()->c_str(), key) ==
0) {
// TODO(T214294528): Support multiple segments in FlatTensor.
if (tensors->Get(i)->segment_index() != 0) {
return Error::InvalidExternalData;
}
return tensors->Get(i);
}
}
Expand Down Expand Up @@ -97,31 +93,68 @@ ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
return metadata_res.error();
}
const auto metadata = metadata_res.get();
if (metadata->segment_index() < 0 || metadata->offset() < 0) {
// Invalid segment_index/offset; malformed PTD file.
return Error::InvalidExternalData;
}
ET_CHECK_OR_RETURN_ERROR(
metadata->segment_index() >= 0 && metadata->offset() >= 0,
InvalidExternalData,
"Invalid segment_index %d or offset %lu; malformed PTD file.",
metadata->segment_index(),
metadata->offset())

Result<const TensorLayout> tensor_layout_res = create_tensor_layout(metadata);
if (!tensor_layout_res.ok()) {
return tensor_layout_res.error();
Result<const TensorLayout> tensor_layout = create_tensor_layout(metadata);
if (!tensor_layout.ok()) {
return tensor_layout.error();
}

// This FreeableBuffer doesn't own the underlying data, and will not free it,
// which is why the free function is a nullptr.
// TODO(T214294528): Remove data_ro_ and instead load the data here, letting
// FreeableBuffer own it.
return FreeableBuffer(
static_cast<const uint8_t*>(data_ro_.data()) + metadata->offset(),
tensor_layout_res.get().nbytes(),
nullptr);
// Load constant data.
const auto* s_data_segment = flat_tensor_->segments();
int segment_offset = s_data_segment->Get(0)->offset();
return loader_->load(
header_.segment_base_offset + segment_offset + metadata->offset(),
tensor_layout.get().nbytes(),
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
}

ET_NODISCARD Result<size_t> FlatTensorDataMap::load_data_into(
ET_UNUSED const char* key,
ET_UNUSED void* buffer,
ET_UNUSED size_t size) const {
return Error::NotImplemented;
auto tensor_metadata = flat_tensor_->tensors();

// Get metadata to get nbytes.
Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
get_flat_tensor_metadata(key, tensor_metadata);
if (!metadata_res.ok()) {
return metadata_res.error();
}
const auto metadata = metadata_res.get();
ET_CHECK_OR_RETURN_ERROR(
metadata->segment_index() >= 0 && metadata->offset() >= 0,
InvalidExternalData,
"Invalid segment_index %d or offset %lu; malformed PTD file.",
metadata->segment_index(),
metadata->offset())

Result<const TensorLayout> tensor_layout = create_tensor_layout(metadata);
if (!tensor_layout.ok()) {
return tensor_layout.error();
}
ET_CHECK_OR_RETURN_ERROR(
size < tensor_layout.get().nbytes(),
InvalidArgument,
"Buffer size %zu is smaller than tensor size %zu",
size,
tensor_layout.get().nbytes())

const auto* s_data_segment = flat_tensor_->segments();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this code is repeated and it's about 20 lines; I would extract a utility function here

Copy link
Contributor Author

@lucylq lucylq Feb 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

int segment_offset = s_data_segment->Get(0)->offset();
DataLoader::SegmentInfo info = DataLoader::SegmentInfo(
DataLoader::SegmentInfo::Type::Mutable, 0, nullptr);

return loader_->load_into(
header_.segment_base_offset + segment_offset + metadata->offset(),
tensor_layout.get().nbytes(),
info,
buffer);
}

ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
Expand All @@ -138,45 +171,34 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(

/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
DataLoader* loader) {
// Load data map.
size_t flatbuffer_offset = 0;
size_t flatbuffer_size = 0;
size_t segment_base_offset = 0;
size_t segment_data_size = 0;
{
// Check header.
Result<FreeableBuffer> header = loader->load(
/*offset=*/0,
FlatTensorHeader::kNumHeadBytes,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
if (!header.ok()) {
return header.error();
}
Result<FlatTensorHeader> fh =
FlatTensorHeader::Parse(header->data(), header->size());
if (fh.ok()) {
// The header has the data map size.
flatbuffer_offset = fh->flatbuffer_offset;
flatbuffer_size = fh->flatbuffer_size;
segment_base_offset = fh->segment_base_offset;
segment_data_size = fh->segment_data_size;
} else if (fh.error() == Error::NotFound) {
// No header, throw error.
ET_LOG(Error, "No FlatTensorHeader found.");
return fh.error();
} else {
// corruption, throw error.
ET_LOG(Error, "Flat tensor header may be corrupt.");
return fh.error();
}
// Check header.
Result<FreeableBuffer> header = loader->load(
/*offset=*/0,
FlatTensorHeader::kNumHeadBytes,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
if (!header.ok()) {
ET_LOG(Error, "Failed to load header.");
return header.error();
}
Result<FlatTensorHeader> fh =
FlatTensorHeader::Parse(header->data(), header->size());
if (fh.error() == Error::NotFound) {
// No header, throw error.
ET_LOG(Error, "No FlatTensorHeader found.");
return fh.error();
} else if (fh.error() != Error::Ok) {
// corruption, throw error.
ET_LOG(Error, "Flat tensor header may be corrupt.");
return fh.error();
}

// Load flatbuffer data as a segment.
Result<FreeableBuffer> flat_tensor_data = loader->load(
/*offset=*/0,
flatbuffer_offset + flatbuffer_size,
fh->flatbuffer_offset + fh->flatbuffer_size,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
if (!flat_tensor_data.ok()) {
ET_LOG(Error, "Failed to load flat_tensor data.");
return flat_tensor_data.error();
}

Expand Down Expand Up @@ -204,54 +226,8 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
flat_tensor_flatbuffer::GetFlatTensor(flat_tensor_data->data());

// Validate flatbuffer data.
flatbuffers::Verifier verifier(
reinterpret_cast<const uint8_t*>(flat_tensor_data->data()),
flat_tensor_data->size());
bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer(verifier);
ET_CHECK_OR_RETURN_ERROR(
ok,
InvalidExternalData,
"Verification failed; data may be truncated or corrupt");

// Get pointer to tensor metadata.
const auto* s_tensor_metadata = flat_tensor->tensors();
if (s_tensor_metadata == nullptr) {
ET_LOG(Error, "FlatTensor has no tensor metadata.");
return Error::InvalidExternalData;
}

// Load constant data.
const auto* s_data_segment = flat_tensor->segments();

// TODO(T214294528): Support multiple segments in FlatTensor.
if (s_data_segment->size() != 1) {
ET_LOG(
Error,
"FlatTensor has %u segments, only 1 supported.",
s_data_segment->size());
}
// First segment size should be <= the total segment data size.
int segment_size = s_data_segment->Get(0)->size();
int segment_offset = s_data_segment->Get(0)->offset();
if (segment_size > segment_data_size) {
ET_LOG(
Error,
"FlatTensor segment size %d > segment data size %zu",
segment_size,
segment_data_size);
}

Result<FreeableBuffer> data_ro = loader->load(
/*offset=*/segment_base_offset + segment_offset,
segment_size,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
if (!data_ro.ok()) {
return data_ro.error();
}

return FlatTensorDataMap(
std::move(flat_tensor_data.get()), flat_tensor, std::move(data_ro.get()));
fh.get(), std::move(flat_tensor_data.get()), flat_tensor, loader);
}

} // namespace extension
Expand Down
50 changes: 45 additions & 5 deletions extension/flat_tensor/flat_tensor_data_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <executorch/runtime/core/named_data_map.h>

#include <executorch/extension/flat_tensor/serialize/flat_tensor_header.h>

#include <executorch/runtime/core/data_loader.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/result.h>
Expand Down Expand Up @@ -41,17 +43,50 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {
static executorch::runtime::Result<FlatTensorDataMap> load(
executorch::runtime::DataLoader* loader);

/**
* Retrieve the metadata for the specified key.
*
* @param[in] key The name of the tensor to get metadata on.
*
* @return Error::NotFound if the key is not present.
*/
ET_NODISCARD
executorch::runtime::Result<const executorch::runtime::TensorLayout>
get_metadata(const char* key) const override;

/**
* Retrieve read-only data for the specified key.
*
* @param[in] key The name of the tensor to get data on.
*
* @return error if the key is not present or data cannot be loaded.
*/
ET_NODISCARD
executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
const char* key) const override;

/**
* Loads the data of the specified tensor into the provided buffer.
*
* @param[in] key The name of the tensor to get the data of.
* @param[in] buffer The buffer to load data into. Must point to at least
* `size` bytes of memory.
* @param[in] size The number of bytes to load.
*
* @returns an Error indicating if the load was successful.
*/
ET_NODISCARD executorch::runtime::Result<size_t>
load_data_into(const char* key, void* buffer, size_t size) const override;

/**
* @returns The number of keys in the map.
*/
ET_NODISCARD executorch::runtime::Result<size_t> get_num_keys()
const override;

/**
* @returns The key at the specified index, error if index out of bounds.
*/
ET_NODISCARD executorch::runtime::Result<const char*> get_key(
size_t index) const override;

Expand All @@ -61,26 +96,31 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {

private:
FlatTensorDataMap(
const FlatTensorHeader& header,
executorch::runtime::FreeableBuffer&& flat_tensor_data,
const flat_tensor_flatbuffer::FlatTensor* flat_tensor,
executorch::runtime::FreeableBuffer&& data_ro)
: flat_tensor_data_(std::move(flat_tensor_data)),
executorch::runtime::DataLoader* loader)
: header_(header),
flat_tensor_data_(std::move(flat_tensor_data)),
flat_tensor_(flat_tensor),
data_ro_(std::move(data_ro)) {}
loader_(loader) {}

// Not copyable or assignable.
FlatTensorDataMap(const FlatTensorDataMap& rhs) = delete;
FlatTensorDataMap& operator=(FlatTensorDataMap&& rhs) noexcept = delete;
FlatTensorDataMap& operator=(const FlatTensorDataMap& rhs) = delete;

// FlatTensor header, containing segment_base_offset and segment_data_size.
const FlatTensorHeader header_;

// Serialized flat_tensor flatbuffer data.
executorch::runtime::FreeableBuffer flat_tensor_data_;

// Flatbuffer representation of the flat_tensor.
const flat_tensor_flatbuffer::FlatTensor* flat_tensor_;

// Loaded read-only tensor data.
executorch::runtime::FreeableBuffer data_ro_;
// Data loader, used to load segment data.
executorch::runtime::DataLoader* loader_;
};

} // namespace extension
Expand Down
2 changes: 1 addition & 1 deletion extension/flat_tensor/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def define_common_targets(is_fbcode=False):
}

runtime.cxx_test(
name = "flat_tensor_data_map",
name = "flat_tensor_data_map_test",
srcs = [
"flat_tensor_data_map_test.cpp",
],
Expand Down
Loading