Skip to content

Commit 4dfb1d6

Browse files
committed
[executorch][flat_tensor] implement load into and dont hold onto the segment
Pull Request resolved: #8447. (1) Implement `load_data_into` in FlatTensorDataMap. (2) Do not persist `data_ro` in the FlatTensorDataMap; instead, `get_data` returns the FreeableBuffer given by the data loader. TODO: add a test for `load_data_into`. ghstack-source-id: 267313796 Differential Revision: [D69148652](https://our.internmc.facebook.com/intern/diff/D69148652/)
1 parent 35f3b8a commit 4dfb1d6

File tree

3 files changed

+119
-114
lines changed

3 files changed

+119
-114
lines changed

extension/flat_tensor/flat_tensor_data_map.cpp

Lines changed: 73 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,14 @@ Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
5252
for (int i = 0; i < tensors->size(); i++) {
5353
if (std::strcmp(tensors->Get(i)->fully_qualified_name()->c_str(), key) ==
5454
0) {
55-
// TODO(T214294528): Support multiple segments in FlatTensor.
56-
if (tensors->Get(i)->segment_index() != 0) {
57-
return Error::InvalidExternalData;
58-
}
59-
return tensors->Get(i);
55+
const auto* metadata = tensors->Get(i);
56+
ET_CHECK_OR_RETURN_ERROR(
57+
metadata->segment_index() >= 0 && metadata->offset() >= 0,
58+
InvalidExternalData,
59+
"Invalid segment_index %d or offset %lu; malformed PTD file.",
60+
metadata->segment_index(),
61+
metadata->offset());
62+
return metadata;
6063
}
6164
}
6265
return Error::NotFound;
@@ -89,39 +92,58 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
8992

9093
/**
 * Loads the data of the tensor named `key` through the data loader and
 * returns the FreeableBuffer produced by that loader.
 *
 * @param[in] key The name of the tensor whose data should be loaded.
 *
 * @return The buffer returned by the loader on success. Otherwise the error
 *     from the metadata lookup, from building the tensor layout, or from the
 *     loader; or Error::InvalidExternalData when the tensor's segment_index
 *     does not refer to an entry of the segment table.
 */
ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
    const char* key) const {
  Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
      get_flat_tensor_metadata(key, flat_tensor_->tensors());
  if (!metadata.ok()) {
    return metadata.error();
  }
  Result<const TensorLayout> tensor_layout =
      create_tensor_layout(metadata.get());
  if (!tensor_layout.ok()) {
    return tensor_layout.error();
  }

  // The segment index comes from the serialized file; make sure it refers to
  // an existing entry before indexing into the segment table.
  const auto* segments = flat_tensor_->segments();
  const auto segment_index = metadata.get()->segment_index();
  ET_CHECK_OR_RETURN_ERROR(
      segments != nullptr &&
          static_cast<size_t>(segment_index) < segments->size(),
      InvalidExternalData,
      "Invalid segment_index %zu; malformed PTD file.",
      static_cast<size_t>(segment_index));

  // Keep the offset in the type the flatbuffer accessor returns; storing it
  // in an `int` would truncate large offsets.
  const auto segment_offset = segments->Get(segment_index)->offset();

  // Load constant data. The absolute position is the segment base offset from
  // the header, plus the segment's offset, plus the tensor's offset.
  return loader_->load(
      header_.segment_base_offset + segment_offset + metadata.get()->offset(),
      tensor_layout.get().nbytes(),
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
}
119114

120115
/**
 * Loads the data of the tensor named `key` into a caller-provided buffer.
 *
 * @param[in] key The name of the tensor whose data should be loaded.
 * @param[in] buffer Destination for the tensor data.
 * @param[in] size The capacity of `buffer` in bytes; must be at least the
 *     tensor's nbytes.
 *
 * @return The number of bytes written (the tensor's nbytes) on success.
 *     Otherwise the error from the metadata lookup, from building the tensor
 *     layout, or from the loader; Error::InvalidArgument when `buffer` is too
 *     small; or Error::InvalidExternalData when the tensor's segment_index
 *     does not refer to an entry of the segment table.
 */
ET_NODISCARD Result<size_t> FlatTensorDataMap::load_data_into(
    const char* key,
    void* buffer,
    size_t size) const {
  // Get metadata to get nbytes.
  Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
      get_flat_tensor_metadata(key, flat_tensor_->tensors());
  if (!metadata.ok()) {
    return metadata.error();
  }
  Result<const TensorLayout> tensor_layout =
      create_tensor_layout(metadata.get());
  if (!tensor_layout.ok()) {
    return tensor_layout.error();
  }
  const size_t nbytes = tensor_layout.get().nbytes();

  // The macro returns the error when the condition is FALSE, so the condition
  // must describe the valid case: the buffer can hold the whole tensor.
  ET_CHECK_OR_RETURN_ERROR(
      size >= nbytes,
      InvalidArgument,
      "Buffer size %zu is smaller than tensor size %zu",
      size,
      nbytes);

  // The segment index comes from the serialized file; make sure it refers to
  // an existing entry before indexing into the segment table.
  const auto* segments = flat_tensor_->segments();
  const auto segment_index = metadata.get()->segment_index();
  ET_CHECK_OR_RETURN_ERROR(
      segments != nullptr &&
          static_cast<size_t>(segment_index) < segments->size(),
      InvalidExternalData,
      "Invalid segment_index %zu; malformed PTD file.",
      static_cast<size_t>(segment_index));

  // Keep the offset in the type the flatbuffer accessor returns; storing it
  // in an `int` would truncate large offsets.
  const auto segment_offset = segments->Get(segment_index)->offset();

  // NOTE(review): get_data() tags its load as Type::External while this path
  // uses Type::Mutable — confirm the difference is intentional.
  DataLoader::SegmentInfo info = DataLoader::SegmentInfo(
      DataLoader::SegmentInfo::Type::Mutable, 0, nullptr);

  // Report the loader's failure as-is; on success return the byte count
  // rather than converting an Error into the Result.
  const Error load_status = loader_->load_into(
      header_.segment_base_offset + segment_offset + metadata.get()->offset(),
      nbytes,
      info,
      buffer);
  if (load_status != Error::Ok) {
    return load_status;
  }
  return nbytes;
}
126148

127149
ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
@@ -138,45 +160,34 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
138160

139161
/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
140162
DataLoader* loader) {
141-
// Load data map.
142-
size_t flatbuffer_offset = 0;
143-
size_t flatbuffer_size = 0;
144-
size_t segment_base_offset = 0;
145-
size_t segment_data_size = 0;
146-
{
147-
// Check header.
148-
Result<FreeableBuffer> header = loader->load(
149-
/*offset=*/0,
150-
FlatTensorHeader::kNumHeadBytes,
151-
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
152-
if (!header.ok()) {
153-
return header.error();
154-
}
155-
Result<FlatTensorHeader> fh =
156-
FlatTensorHeader::Parse(header->data(), header->size());
157-
if (fh.ok()) {
158-
// The header has the data map size.
159-
flatbuffer_offset = fh->flatbuffer_offset;
160-
flatbuffer_size = fh->flatbuffer_size;
161-
segment_base_offset = fh->segment_base_offset;
162-
segment_data_size = fh->segment_data_size;
163-
} else if (fh.error() == Error::NotFound) {
164-
// No header, throw error.
165-
ET_LOG(Error, "No FlatTensorHeader found.");
166-
return fh.error();
167-
} else {
168-
// corruption, throw error.
169-
ET_LOG(Error, "Flat tensor header may be corrupt.");
170-
return fh.error();
171-
}
163+
// Check header.
164+
Result<FreeableBuffer> header = loader->load(
165+
/*offset=*/0,
166+
FlatTensorHeader::kNumHeadBytes,
167+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
168+
if (!header.ok()) {
169+
ET_LOG(Error, "Failed to load header.");
170+
return header.error();
171+
}
172+
Result<FlatTensorHeader> fh =
173+
FlatTensorHeader::Parse(header->data(), header->size());
174+
if (fh.error() == Error::NotFound) {
175+
// No header, throw error.
176+
ET_LOG(Error, "No FlatTensorHeader found.");
177+
return fh.error();
178+
} else if (fh.error() != Error::Ok) {
179+
// corruption, throw error.
180+
ET_LOG(Error, "Flat tensor header may be corrupt.");
181+
return fh.error();
172182
}
173183

174184
// Load flatbuffer data as a segment.
175185
Result<FreeableBuffer> flat_tensor_data = loader->load(
176186
/*offset=*/0,
177-
flatbuffer_offset + flatbuffer_size,
187+
fh->flatbuffer_offset + fh->flatbuffer_size,
178188
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
179189
if (!flat_tensor_data.ok()) {
190+
ET_LOG(Error, "Failed to load flat_tensor data.");
180191
return flat_tensor_data.error();
181192
}
182193

@@ -204,54 +215,8 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
204215
const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
205216
flat_tensor_flatbuffer::GetFlatTensor(flat_tensor_data->data());
206217

207-
// Validate flatbuffer data.
208-
flatbuffers::Verifier verifier(
209-
reinterpret_cast<const uint8_t*>(flat_tensor_data->data()),
210-
flat_tensor_data->size());
211-
bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer(verifier);
212-
ET_CHECK_OR_RETURN_ERROR(
213-
ok,
214-
InvalidExternalData,
215-
"Verification failed; data may be truncated or corrupt");
216-
217-
// Get pointer to tensor metadata.
218-
const auto* s_tensor_metadata = flat_tensor->tensors();
219-
if (s_tensor_metadata == nullptr) {
220-
ET_LOG(Error, "FlatTensor has no tensor metadata.");
221-
return Error::InvalidExternalData;
222-
}
223-
224-
// Load constant data.
225-
const auto* s_data_segment = flat_tensor->segments();
226-
227-
// TODO(T214294528): Support multiple segments in FlatTensor.
228-
if (s_data_segment->size() != 1) {
229-
ET_LOG(
230-
Error,
231-
"FlatTensor has %u segments, only 1 supported.",
232-
s_data_segment->size());
233-
}
234-
// First segment size should be <= the total segment data size.
235-
int segment_size = s_data_segment->Get(0)->size();
236-
int segment_offset = s_data_segment->Get(0)->offset();
237-
if (segment_size > segment_data_size) {
238-
ET_LOG(
239-
Error,
240-
"FlatTensor segment size %d > segment data size %zu",
241-
segment_size,
242-
segment_data_size);
243-
}
244-
245-
Result<FreeableBuffer> data_ro = loader->load(
246-
/*offset=*/segment_base_offset + segment_offset,
247-
segment_size,
248-
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
249-
if (!data_ro.ok()) {
250-
return data_ro.error();
251-
}
252-
253218
return FlatTensorDataMap(
254-
std::move(flat_tensor_data.get()), flat_tensor, std::move(data_ro.get()));
219+
fh.get(), std::move(flat_tensor_data.get()), flat_tensor, loader);
255220
}
256221

257222
} // namespace extension

extension/flat_tensor/flat_tensor_data_map.h

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <executorch/runtime/core/named_data_map.h>
1212

13+
#include <executorch/extension/flat_tensor/serialize/flat_tensor_header.h>
14+
1315
#include <executorch/runtime/core/data_loader.h>
1416
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1517
#include <executorch/runtime/core/result.h>
@@ -41,17 +43,50 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {
4143
static executorch::runtime::Result<FlatTensorDataMap> load(
4244
executorch::runtime::DataLoader* loader);
4345

46+
/**
47+
* Retrieve the metadata for the specified key.
48+
*
49+
* @param[in] key The name of the tensor to get metadata on.
50+
*
51+
* @return Error::NotFound if the key is not present.
52+
*/
4453
ET_NODISCARD
4554
executorch::runtime::Result<const executorch::runtime::TensorLayout>
4655
get_metadata(const char* key) const override;
56+
57+
/**
58+
* Retrieve read-only data for the specified key.
59+
*
60+
* @param[in] key The name of the tensor to get data on.
61+
*
62+
* @return error if the key is not present or data cannot be loaded.
63+
*/
4764
ET_NODISCARD
4865
executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
4966
const char* key) const override;
67+
68+
/**
69+
* Loads the data of the specified tensor into the provided buffer.
70+
*
71+
* @param[in] key The name of the tensor to get the data of.
72+
* @param[in] buffer The buffer to load data into. Must point to at least
73+
* `size` bytes of memory.
74+
* @param[in] size The number of bytes to load.
75+
*
76+
* @returns an Error indicating if the load was successful.
77+
*/
5078
ET_NODISCARD executorch::runtime::Result<size_t>
5179
load_data_into(const char* key, void* buffer, size_t size) const override;
5280

81+
/**
82+
* @returns The number of keys in the map.
83+
*/
5384
ET_NODISCARD executorch::runtime::Result<size_t> get_num_keys()
5485
const override;
86+
87+
/**
88+
* @returns The key at the specified index, error if index out of bounds.
89+
*/
5590
ET_NODISCARD executorch::runtime::Result<const char*> get_key(
5691
size_t index) const override;
5792

@@ -61,26 +96,31 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {
6196

6297
private:
6398
FlatTensorDataMap(
99+
const FlatTensorHeader& header,
64100
executorch::runtime::FreeableBuffer&& flat_tensor_data,
65101
const flat_tensor_flatbuffer::FlatTensor* flat_tensor,
66-
executorch::runtime::FreeableBuffer&& data_ro)
67-
: flat_tensor_data_(std::move(flat_tensor_data)),
102+
executorch::runtime::DataLoader* loader)
103+
: header_(header),
104+
flat_tensor_data_(std::move(flat_tensor_data)),
68105
flat_tensor_(flat_tensor),
69-
data_ro_(std::move(data_ro)) {}
106+
loader_(loader) {}
70107

71108
// Not copyable or assignable.
72109
FlatTensorDataMap(const FlatTensorDataMap& rhs) = delete;
73110
FlatTensorDataMap& operator=(FlatTensorDataMap&& rhs) noexcept = delete;
74111
FlatTensorDataMap& operator=(const FlatTensorDataMap& rhs) = delete;
75112

113+
// FlatTensor header, containing segment_base_offset and segment_data_size.
114+
const FlatTensorHeader header_;
115+
76116
// Serialized flat_tensor flatbuffer data.
77117
executorch::runtime::FreeableBuffer flat_tensor_data_;
78118

79119
// Flatbuffer representation of the flat_tensor.
80120
const flat_tensor_flatbuffer::FlatTensor* flat_tensor_;
81121

82-
// Loaded read-only tensor data.
83-
executorch::runtime::FreeableBuffer data_ro_;
122+
// Data loader, used to load segment data.
123+
executorch::runtime::DataLoader* loader_;
84124
};
85125

86126
} // namespace extension

extension/flat_tensor/test/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def define_common_targets(is_fbcode=False):
4040
}
4141

4242
runtime.cxx_test(
43-
name = "flat_tensor_data_map",
43+
name = "flat_tensor_data_map_test",
4444
srcs = [
4545
"flat_tensor_data_map_test.cpp",
4646
],

0 commit comments

Comments
 (0)