diff --git a/extension/flat_tensor/flat_tensor_data_map.cpp b/extension/flat_tensor/flat_tensor_data_map.cpp index 20ebc99994a..ff526e359d4 100644 --- a/extension/flat_tensor/flat_tensor_data_map.cpp +++ b/extension/flat_tensor/flat_tensor_data_map.cpp @@ -52,11 +52,14 @@ Result get_flat_tensor_metadata( for (int i = 0; i < tensors->size(); i++) { if (std::strcmp(tensors->Get(i)->fully_qualified_name()->c_str(), key) == 0) { - // TODO(T214294528): Support multiple segments in FlatTensor. - if (tensors->Get(i)->segment_index() != 0) { - return Error::InvalidExternalData; - } - return tensors->Get(i); + const auto* metadata = tensors->Get(i); + ET_CHECK_OR_RETURN_ERROR( + metadata->segment_index() >= 0 && metadata->offset() >= 0, + InvalidExternalData, + "Invalid segment_index %d or offset %" PRIu64 "; malformed PTD file.", + metadata->segment_index(), + metadata->offset()); + return metadata; } } return Error::NotFound; @@ -75,6 +78,23 @@ Result create_tensor_layout( scalar_type); } +Result get_and_check_segment_offset( + const flatbuffers::Vector< + flatbuffers::Offset>* segments, + const flat_tensor_flatbuffer::TensorMetadata* metadata) { + ET_CHECK_OR_RETURN_ERROR( + segments != nullptr, + InvalidExternalData, + "No segments in external data flatbuffer."); + + ET_CHECK_OR_RETURN_ERROR( + metadata->segment_index() < segments->size(), + InvalidExternalData, + "Invalid segment_index %d; malformed PTD file.", + metadata->segment_index()); + return segments->Get(metadata->segment_index())->offset(); +} + } // namespace ET_NODISCARD Result FlatTensorDataMap::get_metadata( @@ -89,39 +109,73 @@ ET_NODISCARD Result FlatTensorDataMap::get_metadata( ET_NODISCARD Result FlatTensorDataMap::get_data( const char* key) const { - auto tensor_metadata = flat_tensor_->tensors(); - - Result metadata_res = - get_flat_tensor_metadata(key, tensor_metadata); - if (!metadata_res.ok()) { - return metadata_res.error(); + Result metadata = + get_flat_tensor_metadata(key, flat_tensor_->tensors()); + if (!metadata.ok()) { + return metadata.error(); } - const auto metadata = metadata_res.get(); - if (metadata->segment_index() < 0 || metadata->offset() < 0) { - // Invalid segment_index/offset; malformed PTD file. - return Error::InvalidExternalData; + Result tensor_layout = + create_tensor_layout(metadata.get()); + if (!tensor_layout.ok()) { + return tensor_layout.error(); } - - Result tensor_layout_res = create_tensor_layout(metadata); - if (!tensor_layout_res.ok()) { - return tensor_layout_res.error(); + Result segment_offset = + get_and_check_segment_offset(flat_tensor_->segments(), metadata.get()); + if (!segment_offset.ok()) { + return segment_offset.error(); } - // This FreeableBuffer doesn't own the underlying data, and will not free it, - // which is why the free function is a nullptr. - // TODO(T214294528): Remove data_ro_ and instead load the data here, letting - // FreeableBuffer own it. - return FreeableBuffer( - static_cast(data_ro_.data()) + metadata->offset(), - tensor_layout_res.get().nbytes(), - nullptr); + // Load constant data. + ET_CHECK_OR_RETURN_ERROR( + segment_offset.get() < + header_.segment_base_offset + header_.segment_data_size, + InvalidExternalData, + "Invalid segment offset %d is larger than the segment_base_offset + segment_data_size %" PRIu64 + "; malformed PTD file.", + segment_offset.get(), + header_.segment_base_offset + header_.segment_data_size); + return loader_->load( + header_.segment_base_offset + segment_offset.get() + + metadata.get()->offset(), + tensor_layout.get().nbytes(), + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); } ET_NODISCARD Result FlatTensorDataMap::load_data_into( ET_UNUSED const char* key, ET_UNUSED void* buffer, ET_UNUSED size_t size) const { - return Error::NotImplemented; + Result metadata = + get_flat_tensor_metadata(key, flat_tensor_->tensors()); + if (!metadata.ok()) { + return metadata.error(); + } + Result tensor_layout = + create_tensor_layout(metadata.get()); + if (!tensor_layout.ok()) { + return tensor_layout.error(); + } + ET_CHECK_OR_RETURN_ERROR( + size < tensor_layout.get().nbytes(), + InvalidArgument, + "Buffer size %zu is smaller than tensor size %zu", + size, + tensor_layout.get().nbytes()); + + Result segment_offset = + get_and_check_segment_offset(flat_tensor_->segments(), metadata.get()); + if (!segment_offset.ok()) { + return segment_offset.error(); + } + // Load mutable data. + DataLoader::SegmentInfo info = DataLoader::SegmentInfo( + DataLoader::SegmentInfo::Type::Mutable, 0, nullptr); + return loader_->load_into( + header_.segment_base_offset + segment_offset.get() + + metadata.get()->offset(), + tensor_layout.get().nbytes(), + info, + buffer); } ET_NODISCARD Result FlatTensorDataMap::get_num_keys() const { @@ -138,45 +192,34 @@ ET_NODISCARD Result FlatTensorDataMap::get_key( /* static */ Result FlatTensorDataMap::load( DataLoader* loader) { - // Load data map. - size_t flatbuffer_offset = 0; - size_t flatbuffer_size = 0; - size_t segment_base_offset = 0; - size_t segment_data_size = 0; - { - // Check header. - Result header = loader->load( - /*offset=*/0, - FlatTensorHeader::kNumHeadBytes, - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); - if (!header.ok()) { - return header.error(); - } - Result fh = - FlatTensorHeader::Parse(header->data(), header->size()); - if (fh.ok()) { - // The header has the data map size. - flatbuffer_offset = fh->flatbuffer_offset; - flatbuffer_size = fh->flatbuffer_size; - segment_base_offset = fh->segment_base_offset; - segment_data_size = fh->segment_data_size; - } else if (fh.error() == Error::NotFound) { - // No header, throw error. - ET_LOG(Error, "No FlatTensorHeader found."); - return fh.error(); - } else { - // corruption, throw error. - ET_LOG(Error, "Flat tensor header may be corrupt."); - return fh.error(); - } + // Check header. + Result header = loader->load( + /*offset=*/0, + FlatTensorHeader::kNumHeadBytes, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); + if (!header.ok()) { + ET_LOG(Error, "Failed to load header."); + return header.error(); + } + Result fh = + FlatTensorHeader::Parse(header->data(), header->size()); + if (fh.error() == Error::NotFound) { + // No header, throw error. + ET_LOG(Error, "No FlatTensorHeader found."); + return fh.error(); + } else if (fh.error() != Error::Ok) { + // corruption, throw error. + ET_LOG(Error, "Flat tensor header may be corrupt."); + return fh.error(); } // Load flatbuffer data as a segment. Result flat_tensor_data = loader->load( /*offset=*/0, - flatbuffer_offset + flatbuffer_size, + fh->flatbuffer_offset + fh->flatbuffer_size, DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); if (!flat_tensor_data.ok()) { + ET_LOG(Error, "Failed to load flat_tensor data."); return flat_tensor_data.error(); } @@ -204,54 +247,8 @@ ET_NODISCARD Result FlatTensorDataMap::get_key( const flat_tensor_flatbuffer::FlatTensor* flat_tensor = flat_tensor_flatbuffer::GetFlatTensor(flat_tensor_data->data()); - // Validate flatbuffer data. - flatbuffers::Verifier verifier( - reinterpret_cast(flat_tensor_data->data()), - flat_tensor_data->size()); - bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer(verifier); - ET_CHECK_OR_RETURN_ERROR( - ok, - InvalidExternalData, - "Verification failed; data may be truncated or corrupt"); - - // Get pointer to tensor metadata. - const auto* s_tensor_metadata = flat_tensor->tensors(); - if (s_tensor_metadata == nullptr) { - ET_LOG(Error, "FlatTensor has no tensor metadata."); - return Error::InvalidExternalData; - } - - // Load constant data. - const auto* s_data_segment = flat_tensor->segments(); - - // TODO(T214294528): Support multiple segments in FlatTensor. - if (s_data_segment->size() != 1) { - ET_LOG( - Error, - "FlatTensor has %u segments, only 1 supported.", - s_data_segment->size()); - } - // First segment size should be <= the total segment data size. - int segment_size = s_data_segment->Get(0)->size(); - int segment_offset = s_data_segment->Get(0)->offset(); - if (segment_size > segment_data_size) { - ET_LOG( - Error, - "FlatTensor segment size %d > segment data size %zu", - segment_size, - segment_data_size); - } - - Result data_ro = loader->load( - /*offset=*/segment_base_offset + segment_offset, - segment_size, - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); - if (!data_ro.ok()) { - return data_ro.error(); - } - return FlatTensorDataMap( - std::move(flat_tensor_data.get()), flat_tensor, std::move(data_ro.get())); + fh.get(), std::move(flat_tensor_data.get()), flat_tensor, loader); } } // namespace extension diff --git a/extension/flat_tensor/flat_tensor_data_map.h b/extension/flat_tensor/flat_tensor_data_map.h index 7bd33e68927..00f4bf07d19 100644 --- a/extension/flat_tensor/flat_tensor_data_map.h +++ b/extension/flat_tensor/flat_tensor_data_map.h @@ -10,6 +10,8 @@ #include +#include + #include #include #include @@ -41,17 +43,50 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap { static executorch::runtime::Result load( executorch::runtime::DataLoader* loader); + /** + * Retrieve the metadata for the specified key. + * + * @param[in] key The name of the tensor to get metadata on. + * + * @return Error::NotFound if the key is not present. + */ ET_NODISCARD executorch::runtime::Result get_metadata(const char* key) const override; + + /** + * Retrieve read-only data for the specified key. + * + * @param[in] key The name of the tensor to get data on. + * + * @return error if the key is not present or data cannot be loaded. + */ ET_NODISCARD executorch::runtime::Result get_data( const char* key) const override; + + /** + * Loads the data of the specified tensor into the provided buffer. + * + * @param[in] key The name of the tensor to get the data of. + * @param[in] buffer The buffer to load data into. Must point to at least + * `size` bytes of memory. + * @param[in] size The number of bytes to load. + * + * @returns an Error indicating if the load was successful. + */ ET_NODISCARD executorch::runtime::Result load_data_into(const char* key, void* buffer, size_t size) const override; + /** + * @returns The number of keys in the map. + */ ET_NODISCARD executorch::runtime::Result get_num_keys() const override; + + /** + * @returns The key at the specified index, error if index out of bounds. + */ ET_NODISCARD executorch::runtime::Result get_key( size_t index) const override; @@ -61,26 +96,31 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap { private: FlatTensorDataMap( + const FlatTensorHeader& header, executorch::runtime::FreeableBuffer&& flat_tensor_data, const flat_tensor_flatbuffer::FlatTensor* flat_tensor, - executorch::runtime::FreeableBuffer&& data_ro) - : flat_tensor_data_(std::move(flat_tensor_data)), + executorch::runtime::DataLoader* loader) + : header_(header), + flat_tensor_data_(std::move(flat_tensor_data)), flat_tensor_(flat_tensor), - data_ro_(std::move(data_ro)) {} + loader_(loader) {} // Not copyable or assignable. FlatTensorDataMap(const FlatTensorDataMap& rhs) = delete; FlatTensorDataMap& operator=(FlatTensorDataMap&& rhs) noexcept = delete; FlatTensorDataMap& operator=(const FlatTensorDataMap& rhs) = delete; + // FlatTensor header, containing segment_base_offset and segment_data_size. + const FlatTensorHeader header_; + // Serialized flat_tensor flatbuffer data. executorch::runtime::FreeableBuffer flat_tensor_data_; // Flatbuffer representation of the flat_tensor. const flat_tensor_flatbuffer::FlatTensor* flat_tensor_; - // Loaded read-only tensor data. - executorch::runtime::FreeableBuffer data_ro_; + // Data loader, used to load segment data. + executorch::runtime::DataLoader* loader_; }; } // namespace extension diff --git a/extension/flat_tensor/test/targets.bzl b/extension/flat_tensor/test/targets.bzl index bc04edfbe1e..28baace3eeb 100644 --- a/extension/flat_tensor/test/targets.bzl +++ b/extension/flat_tensor/test/targets.bzl @@ -40,7 +40,7 @@ def define_common_targets(is_fbcode=False): } runtime.cxx_test( - name = "flat_tensor_data_map", + name = "flat_tensor_data_map_test", srcs = [ "flat_tensor_data_map_test.cpp", ],