@@ -52,10 +52,6 @@ Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
5252 for (int i = 0 ; i < tensors->size (); i++) {
5353 if (std::strcmp (tensors->Get (i)->fully_qualified_name ()->c_str (), key) ==
5454 0 ) {
55- // TODO(T214294528): Support multiple segments in FlatTensor.
56- if (tensors->Get (i)->segment_index () != 0 ) {
57- return Error::InvalidExternalData;
58- }
5955 return tensors->Get (i);
6056 }
6157 }
@@ -97,31 +93,68 @@ ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
9793 return metadata_res.error ();
9894 }
9995 const auto metadata = metadata_res.get ();
100- if (metadata->segment_index () < 0 || metadata->offset () < 0 ) {
101- // Invalid segment_index/offset; malformed PTD file.
102- return Error::InvalidExternalData;
103- }
96+ ET_CHECK_OR_RETURN_ERROR (
97+ metadata->segment_index () >= 0 && metadata->offset () >= 0 ,
98+ InvalidExternalData,
99+ " Invalid segment_index %d or offset %lu; malformed PTD file." ,
100+ metadata->segment_index (),
101+ metadata->offset ())
104102
105- Result<const TensorLayout> tensor_layout_res = create_tensor_layout (metadata);
106- if (!tensor_layout_res .ok ()) {
107- return tensor_layout_res .error ();
103+ Result<const TensorLayout> tensor_layout = create_tensor_layout (metadata);
104+ if (!tensor_layout .ok ()) {
105+ return tensor_layout .error ();
108106 }
109107
110- // This FreeableBuffer doesn't own the underlying data, and will not free it,
111- // which is why the free function is a nullptr.
112- // TODO(T214294528): Remove data_ro_ and instead load the data here, letting
113- // FreeableBuffer own it.
114- return FreeableBuffer (
115- static_cast <const uint8_t *>(data_ro_.data ()) + metadata->offset (),
116- tensor_layout_res.get ().nbytes (),
117- nullptr );
108+ // Load constant data.
109+ const auto * s_data_segment = flat_tensor_->segments ();
110+ int segment_offset = s_data_segment->Get (0 )->offset ();
111+ return loader_->load (
112+ header_.segment_base_offset + segment_offset + metadata->offset (),
113+ tensor_layout.get ().nbytes (),
114+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
118115}
119116
120117ET_NODISCARD Result<size_t > FlatTensorDataMap::load_data_into (
121118 ET_UNUSED const char * key,
122119 ET_UNUSED void * buffer,
123120 ET_UNUSED size_t size) const {
124- return Error::NotImplemented;
121+ auto tensor_metadata = flat_tensor_->tensors ();
122+
123+ // Get metadata to get nbytes.
124+ Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
125+ get_flat_tensor_metadata (key, tensor_metadata);
126+ if (!metadata_res.ok ()) {
127+ return metadata_res.error ();
128+ }
129+ const auto metadata = metadata_res.get ();
130+ ET_CHECK_OR_RETURN_ERROR (
131+ metadata->segment_index () >= 0 && metadata->offset () >= 0 ,
132+ InvalidExternalData,
133+ " Invalid segment_index %d or offset %lu; malformed PTD file." ,
134+ metadata->segment_index (),
135+ metadata->offset ())
136+
137+ Result<const TensorLayout> tensor_layout = create_tensor_layout (metadata);
138+ if (!tensor_layout.ok ()) {
139+ return tensor_layout.error ();
140+ }
141+ ET_CHECK_OR_RETURN_ERROR (
142+ size < tensor_layout.get ().nbytes (),
143+ InvalidArgument,
144+ " Buffer size %zu is smaller than tensor size %zu" ,
145+ size,
146+ tensor_layout.get ().nbytes ())
147+
148+ const auto * s_data_segment = flat_tensor_->segments ();
149+ int segment_offset = s_data_segment->Get (0 )->offset ();
150+ DataLoader::SegmentInfo info = DataLoader::SegmentInfo (
151+ DataLoader::SegmentInfo::Type::Mutable, 0 , nullptr );
152+
153+ return loader_->load_into (
154+ header_.segment_base_offset + segment_offset + metadata->offset (),
155+ tensor_layout.get ().nbytes (),
156+ info,
157+ buffer);
125158}
126159
127160ET_NODISCARD Result<size_t > FlatTensorDataMap::get_num_keys () const {
@@ -138,45 +171,34 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
138171
139172/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load (
140173 DataLoader* loader) {
141- // Load data map.
142- size_t flatbuffer_offset = 0 ;
143- size_t flatbuffer_size = 0 ;
144- size_t segment_base_offset = 0 ;
145- size_t segment_data_size = 0 ;
146- {
147- // Check header.
148- Result<FreeableBuffer> header = loader->load (
149- /* offset=*/ 0 ,
150- FlatTensorHeader::kNumHeadBytes ,
151- DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
152- if (!header.ok ()) {
153- return header.error ();
154- }
155- Result<FlatTensorHeader> fh =
156- FlatTensorHeader::Parse (header->data (), header->size ());
157- if (fh.ok ()) {
158- // The header has the data map size.
159- flatbuffer_offset = fh->flatbuffer_offset ;
160- flatbuffer_size = fh->flatbuffer_size ;
161- segment_base_offset = fh->segment_base_offset ;
162- segment_data_size = fh->segment_data_size ;
163- } else if (fh.error () == Error::NotFound) {
164- // No header, throw error.
165- ET_LOG (Error, " No FlatTensorHeader found." );
166- return fh.error ();
167- } else {
168- // corruption, throw error.
169- ET_LOG (Error, " Flat tensor header may be corrupt." );
170- return fh.error ();
171- }
174+ // Check header.
175+ Result<FreeableBuffer> header = loader->load (
176+ /* offset=*/ 0 ,
177+ FlatTensorHeader::kNumHeadBytes ,
178+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
179+ if (!header.ok ()) {
180+ ET_LOG (Error, " Failed to load header." );
181+ return header.error ();
182+ }
183+ Result<FlatTensorHeader> fh =
184+ FlatTensorHeader::Parse (header->data (), header->size ());
185+ if (fh.error () == Error::NotFound) {
186+ // No header, throw error.
187+ ET_LOG (Error, " No FlatTensorHeader found." );
188+ return fh.error ();
189+ } else if (fh.error () != Error::Ok) {
190+ // corruption, throw error.
191+ ET_LOG (Error, " Flat tensor header may be corrupt." );
192+ return fh.error ();
172193 }
173194
174195 // Load flatbuffer data as a segment.
175196 Result<FreeableBuffer> flat_tensor_data = loader->load (
176197 /* offset=*/ 0 ,
177- flatbuffer_offset + flatbuffer_size,
198+ fh-> flatbuffer_offset + fh-> flatbuffer_size ,
178199 DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
179200 if (!flat_tensor_data.ok ()) {
201+ ET_LOG (Error, " Failed to load flat_tensor data." );
180202 return flat_tensor_data.error ();
181203 }
182204
@@ -204,54 +226,8 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
204226 const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
205227 flat_tensor_flatbuffer::GetFlatTensor (flat_tensor_data->data ());
206228
207- // Validate flatbuffer data.
208- flatbuffers::Verifier verifier (
209- reinterpret_cast <const uint8_t *>(flat_tensor_data->data ()),
210- flat_tensor_data->size ());
211- bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer (verifier);
212- ET_CHECK_OR_RETURN_ERROR (
213- ok,
214- InvalidExternalData,
215- " Verification failed; data may be truncated or corrupt" );
216-
217- // Get pointer to tensor metadata.
218- const auto * s_tensor_metadata = flat_tensor->tensors ();
219- if (s_tensor_metadata == nullptr ) {
220- ET_LOG (Error, " FlatTensor has no tensor metadata." );
221- return Error::InvalidExternalData;
222- }
223-
224- // Load constant data.
225- const auto * s_data_segment = flat_tensor->segments ();
226-
227- // TODO(T214294528): Support multiple segments in FlatTensor.
228- if (s_data_segment->size () != 1 ) {
229- ET_LOG (
230- Error,
231- " FlatTensor has %u segments, only 1 supported." ,
232- s_data_segment->size ());
233- }
234- // First segment size should be <= the total segment data size.
235- int segment_size = s_data_segment->Get (0 )->size ();
236- int segment_offset = s_data_segment->Get (0 )->offset ();
237- if (segment_size > segment_data_size) {
238- ET_LOG (
239- Error,
240- " FlatTensor segment size %d > segment data size %zu" ,
241- segment_size,
242- segment_data_size);
243- }
244-
245- Result<FreeableBuffer> data_ro = loader->load (
246- /* offset=*/ segment_base_offset + segment_offset,
247- segment_size,
248- DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
249- if (!data_ro.ok ()) {
250- return data_ro.error ();
251- }
252-
253229 return FlatTensorDataMap (
254- std::move (flat_tensor_data .get ()), flat_tensor, std::move (data_ro .get ()));
230+ fh .get (), std::move (flat_tensor_data .get ()), flat_tensor, loader );
255231}
256232
257233} // namespace extension
0 commit comments