Skip to content

Commit 5939287

Browse files
committed
[executorch][flat_tensor] implement load into and dont hold onto the segment
1. Implement load_into in FlatTensorDataMap 2. Do not persist 'data_ro' in the FlatTensorDataMap. From `get_data`, return the FreeableBuffer given by the data loader. TODO: add test for load_into. Differential Revision: [D69148652](https://our.internmc.facebook.com/intern/diff/D69148652/) ghstack-source-id: 266205806 Pull Request resolved: #8447
1 parent a098671 commit 5939287

File tree

3 files changed

+121
-105
lines changed

3 files changed

+121
-105
lines changed

extension/flat_tensor/flat_tensor_data_map.cpp

Lines changed: 75 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,6 @@ Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
5252
for (int i = 0; i < tensors->size(); i++) {
5353
if (std::strcmp(tensors->Get(i)->fully_qualified_name()->c_str(), key) ==
5454
0) {
55-
// TODO(T214294528): Support multiple segments in FlatTensor.
56-
if (tensors->Get(i)->segment_index() != 0) {
57-
return Error::InvalidExternalData;
58-
}
5955
return tensors->Get(i);
6056
}
6157
}
@@ -97,31 +93,68 @@ ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
9793
return metadata_res.error();
9894
}
9995
const auto metadata = metadata_res.get();
100-
if (metadata->segment_index() < 0 || metadata->offset() < 0) {
101-
// Invalid segment_index/offset; malformed PTD file.
102-
return Error::InvalidExternalData;
103-
}
96+
ET_CHECK_OR_RETURN_ERROR(
97+
metadata->segment_index() >= 0 && metadata->offset() >= 0,
98+
InvalidExternalData,
99+
"Invalid segment_index %d or offset %lu; malformed PTD file.",
100+
metadata->segment_index(),
101+
metadata->offset())
104102

105-
Result<const TensorLayout> tensor_layout_res = create_tensor_layout(metadata);
106-
if (!tensor_layout_res.ok()) {
107-
return tensor_layout_res.error();
103+
Result<const TensorLayout> tensor_layout = create_tensor_layout(metadata);
104+
if (!tensor_layout.ok()) {
105+
return tensor_layout.error();
108106
}
109107

110-
// This FreeableBuffer doesn't own the underlying data, and will not free it,
111-
// which is why the free function is a nullptr.
112-
// TODO(T214294528): Remove data_ro_ and instead load the data here, letting
113-
// FreeableBuffer own it.
114-
return FreeableBuffer(
115-
static_cast<const uint8_t*>(data_ro_.data()) + metadata->offset(),
116-
tensor_layout_res.get().nbytes(),
117-
nullptr);
108+
// Load constant data.
109+
const auto* s_data_segment = flat_tensor_->segments();
110+
int segment_offset = s_data_segment->Get(0)->offset();
111+
return loader_->load(
112+
header_.segment_base_offset + segment_offset + metadata->offset(),
113+
tensor_layout.get().nbytes(),
114+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
118115
}
119116

120117
ET_NODISCARD Result<size_t> FlatTensorDataMap::load_data_into(
121118
ET_UNUSED const char* key,
122119
ET_UNUSED void* buffer,
123120
ET_UNUSED size_t size) const {
124-
return Error::NotImplemented;
121+
auto tensor_metadata = flat_tensor_->tensors();
122+
123+
// Get metadata to get nbytes.
124+
Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
125+
get_flat_tensor_metadata(key, tensor_metadata);
126+
if (!metadata_res.ok()) {
127+
return metadata_res.error();
128+
}
129+
const auto metadata = metadata_res.get();
130+
ET_CHECK_OR_RETURN_ERROR(
131+
metadata->segment_index() >= 0 && metadata->offset() >= 0,
132+
InvalidExternalData,
133+
"Invalid segment_index %d or offset %lu; malformed PTD file.",
134+
metadata->segment_index(),
135+
metadata->offset())
136+
137+
Result<const TensorLayout> tensor_layout = create_tensor_layout(metadata);
138+
if (!tensor_layout.ok()) {
139+
return tensor_layout.error();
140+
}
141+
ET_CHECK_OR_RETURN_ERROR(
142+
size < tensor_layout.get().nbytes(),
143+
InvalidArgument,
144+
"Buffer size %zu is smaller than tensor size %zu",
145+
size,
146+
tensor_layout.get().nbytes())
147+
148+
const auto* s_data_segment = flat_tensor_->segments();
149+
int segment_offset = s_data_segment->Get(0)->offset();
150+
DataLoader::SegmentInfo info = DataLoader::SegmentInfo(
151+
DataLoader::SegmentInfo::Type::Mutable, 0, nullptr);
152+
153+
return loader_->load_into(
154+
header_.segment_base_offset + segment_offset + metadata->offset(),
155+
tensor_layout.get().nbytes(),
156+
info,
157+
buffer);
125158
}
126159

127160
ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
@@ -138,45 +171,34 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
138171

139172
/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
140173
DataLoader* loader) {
141-
// Load data map.
142-
size_t flatbuffer_offset = 0;
143-
size_t flatbuffer_size = 0;
144-
size_t segment_base_offset = 0;
145-
size_t segment_data_size = 0;
146-
{
147-
// Check header.
148-
Result<FreeableBuffer> header = loader->load(
149-
/*offset=*/0,
150-
FlatTensorHeader::kNumHeadBytes,
151-
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
152-
if (!header.ok()) {
153-
return header.error();
154-
}
155-
Result<FlatTensorHeader> fh =
156-
FlatTensorHeader::Parse(header->data(), header->size());
157-
if (fh.ok()) {
158-
// The header has the data map size.
159-
flatbuffer_offset = fh->flatbuffer_offset;
160-
flatbuffer_size = fh->flatbuffer_size;
161-
segment_base_offset = fh->segment_base_offset;
162-
segment_data_size = fh->segment_data_size;
163-
} else if (fh.error() == Error::NotFound) {
164-
// No header, throw error.
165-
ET_LOG(Error, "No FlatTensorHeader found.");
166-
return fh.error();
167-
} else {
168-
// corruption, throw error.
169-
ET_LOG(Error, "Flat tensor header may be corrupt.");
170-
return fh.error();
171-
}
174+
// Check header.
175+
Result<FreeableBuffer> header = loader->load(
176+
/*offset=*/0,
177+
FlatTensorHeader::kNumHeadBytes,
178+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
179+
if (!header.ok()) {
180+
ET_LOG(Error, "Failed to load header.");
181+
return header.error();
182+
}
183+
Result<FlatTensorHeader> fh =
184+
FlatTensorHeader::Parse(header->data(), header->size());
185+
if (fh.error() == Error::NotFound) {
186+
// No header, throw error.
187+
ET_LOG(Error, "No FlatTensorHeader found.");
188+
return fh.error();
189+
} else if (fh.error() != Error::Ok) {
190+
// corruption, throw error.
191+
ET_LOG(Error, "Flat tensor header may be corrupt.");
192+
return fh.error();
172193
}
173194

174195
// Load flatbuffer data as a segment.
175196
Result<FreeableBuffer> flat_tensor_data = loader->load(
176197
/*offset=*/0,
177-
flatbuffer_offset + flatbuffer_size,
198+
fh->flatbuffer_offset + fh->flatbuffer_size,
178199
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
179200
if (!flat_tensor_data.ok()) {
201+
ET_LOG(Error, "Failed to load flat_tensor data.");
180202
return flat_tensor_data.error();
181203
}
182204

@@ -204,54 +226,8 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
204226
const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
205227
flat_tensor_flatbuffer::GetFlatTensor(flat_tensor_data->data());
206228

207-
// Validate flatbuffer data.
208-
flatbuffers::Verifier verifier(
209-
reinterpret_cast<const uint8_t*>(flat_tensor_data->data()),
210-
flat_tensor_data->size());
211-
bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer(verifier);
212-
ET_CHECK_OR_RETURN_ERROR(
213-
ok,
214-
InvalidExternalData,
215-
"Verification failed; data may be truncated or corrupt");
216-
217-
// Get pointer to tensor metadata.
218-
const auto* s_tensor_metadata = flat_tensor->tensors();
219-
if (s_tensor_metadata == nullptr) {
220-
ET_LOG(Error, "FlatTensor has no tensor metadata.");
221-
return Error::InvalidExternalData;
222-
}
223-
224-
// Load constant data.
225-
const auto* s_data_segment = flat_tensor->segments();
226-
227-
// TODO(T214294528): Support multiple segments in FlatTensor.
228-
if (s_data_segment->size() != 1) {
229-
ET_LOG(
230-
Error,
231-
"FlatTensor has %u segments, only 1 supported.",
232-
s_data_segment->size());
233-
}
234-
// First segment size should be <= the total segment data size.
235-
int segment_size = s_data_segment->Get(0)->size();
236-
int segment_offset = s_data_segment->Get(0)->offset();
237-
if (segment_size > segment_data_size) {
238-
ET_LOG(
239-
Error,
240-
"FlatTensor segment size %d > segment data size %zu",
241-
segment_size,
242-
segment_data_size);
243-
}
244-
245-
Result<FreeableBuffer> data_ro = loader->load(
246-
/*offset=*/segment_base_offset + segment_offset,
247-
segment_size,
248-
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
249-
if (!data_ro.ok()) {
250-
return data_ro.error();
251-
}
252-
253229
return FlatTensorDataMap(
254-
std::move(flat_tensor_data.get()), flat_tensor, std::move(data_ro.get()));
230+
fh.get(), std::move(flat_tensor_data.get()), flat_tensor, loader);
255231
}
256232

257233
} // namespace extension

extension/flat_tensor/flat_tensor_data_map.h

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <executorch/runtime/core/named_data_map.h>
1212

13+
#include <executorch/extension/flat_tensor/serialize/flat_tensor_header.h>
14+
1315
#include <executorch/runtime/core/data_loader.h>
1416
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1517
#include <executorch/runtime/core/result.h>
@@ -41,17 +43,50 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {
4143
static executorch::runtime::Result<FlatTensorDataMap> load(
4244
executorch::runtime::DataLoader* loader);
4345

46+
/**
47+
* Retrieve the metadata for the specified key.
48+
*
49+
* @param[in] key The name of the tensor to get metadata on.
50+
*
51+
* @return Error::NotFound if the key is not present.
52+
*/
4453
ET_NODISCARD
4554
executorch::runtime::Result<const executorch::runtime::TensorLayout>
4655
get_metadata(const char* key) const override;
56+
57+
/**
58+
* Retrieve read-only data for the specified key.
59+
*
60+
* @param[in] key The name of the tensor to get data on.
61+
*
62+
* @return error if the key is not present or data cannot be loaded.
63+
*/
4764
ET_NODISCARD
4865
executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
4966
const char* key) const override;
67+
68+
/**
69+
* Loads the data of the specified tensor into the provided buffer.
70+
*
71+
* @param[in] key The name of the tensor to get the data of.
72+
* @param[in] buffer The buffer to load data into. Must point to at least
73+
* `size` bytes of memory.
74+
* @param[in] size The number of bytes to load.
75+
*
76+
* @returns an Error indicating if the load was successful.
77+
*/
5078
ET_NODISCARD executorch::runtime::Result<size_t>
5179
load_data_into(const char* key, void* buffer, size_t size) const override;
5280

81+
/**
82+
* @returns The number of keys in the map.
83+
*/
5384
ET_NODISCARD executorch::runtime::Result<size_t> get_num_keys()
5485
const override;
86+
87+
/**
88+
* @returns The key at the specified index, error if index out of bounds.
89+
*/
5590
ET_NODISCARD executorch::runtime::Result<const char*> get_key(
5691
size_t index) const override;
5792

@@ -61,26 +96,31 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {
6196

6297
private:
6398
FlatTensorDataMap(
99+
const FlatTensorHeader& header,
64100
executorch::runtime::FreeableBuffer&& flat_tensor_data,
65101
const flat_tensor_flatbuffer::FlatTensor* flat_tensor,
66-
executorch::runtime::FreeableBuffer&& data_ro)
67-
: flat_tensor_data_(std::move(flat_tensor_data)),
102+
executorch::runtime::DataLoader* loader)
103+
: header_(header),
104+
flat_tensor_data_(std::move(flat_tensor_data)),
68105
flat_tensor_(flat_tensor),
69-
data_ro_(std::move(data_ro)) {}
106+
loader_(loader) {}
70107

71108
// Not copyable or assignable.
72109
FlatTensorDataMap(const FlatTensorDataMap& rhs) = delete;
73110
FlatTensorDataMap& operator=(FlatTensorDataMap&& rhs) noexcept = delete;
74111
FlatTensorDataMap& operator=(const FlatTensorDataMap& rhs) = delete;
75112

113+
// FlatTensor header, containing segment_base_offset and segment_data_size.
114+
const FlatTensorHeader header_;
115+
76116
// Serialized flat_tensor flatbuffer data.
77117
executorch::runtime::FreeableBuffer flat_tensor_data_;
78118

79119
// Flatbuffer representation of the flat_tensor.
80120
const flat_tensor_flatbuffer::FlatTensor* flat_tensor_;
81121

82-
// Loaded read-only tensor data.
83-
executorch::runtime::FreeableBuffer data_ro_;
122+
// Data loader, used to load segment data.
123+
executorch::runtime::DataLoader* loader_;
84124
};
85125

86126
} // namespace extension

extension/flat_tensor/test/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def define_common_targets(is_fbcode=False):
4040
}
4141

4242
runtime.cxx_test(
43-
name = "flat_tensor_data_map",
43+
name = "flat_tensor_data_map_test",
4444
srcs = [
4545
"flat_tensor_data_map_test.cpp",
4646
],

0 commit comments

Comments
 (0)