@@ -52,11 +52,14 @@ Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
5252 for (int i = 0 ; i < tensors->size (); i++) {
5353 if (std::strcmp (tensors->Get (i)->fully_qualified_name ()->c_str (), key) ==
5454 0 ) {
55- // TODO(T214294528): Support multiple segments in FlatTensor.
56- if (tensors->Get (i)->segment_index () != 0 ) {
57- return Error::InvalidExternalData;
58- }
59- return tensors->Get (i);
55+ const auto * metadata = tensors->Get (i);
56+ ET_CHECK_OR_RETURN_ERROR (
57+ metadata->segment_index () >= 0 && metadata->offset () >= 0 ,
58+ InvalidExternalData,
59+ " Invalid segment_index %d or offset %lu; malformed PTD file." ,
60+ metadata->segment_index (),
61+ metadata->offset ());
62+ return metadata;
6063 }
6164 }
6265 return Error::NotFound;
@@ -89,39 +92,58 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
8992
9093ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data (
9194 const char * key) const {
92- auto tensor_metadata = flat_tensor_->tensors ();
93-
94- Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
95- get_flat_tensor_metadata (key, tensor_metadata);
96- if (!metadata_res.ok ()) {
97- return metadata_res.error ();
98- }
99- const auto metadata = metadata_res.get ();
100- if (metadata->segment_index () < 0 || metadata->offset () < 0 ) {
101- // Invalid segment_index/offset; malformed PTD file.
102- return Error::InvalidExternalData;
95+ Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
96+ get_flat_tensor_metadata (key, flat_tensor_->tensors ());
97+ if (!metadata.ok ()) {
98+ return metadata.error ();
10399 }
104-
105- Result< const TensorLayout> tensor_layout_res = create_tensor_layout (metadata);
106- if (!tensor_layout_res .ok ()) {
107- return tensor_layout_res .error ();
100+ Result< const TensorLayout> tensor_layout =
101+ create_tensor_layout (metadata. get () );
102+ if (!tensor_layout .ok ()) {
103+ return tensor_layout .error ();
108104 }
109105
110- // This FreeableBuffer doesn't own the underlying data, and will not free it,
111- // which is why the free function is a nullptr.
112- // TODO(T214294528): Remove data_ro_ and instead load the data here, letting
113- // FreeableBuffer own it.
114- return FreeableBuffer (
115- static_cast <const uint8_t *>(data_ro_.data ()) + metadata->offset (),
116- tensor_layout_res.get ().nbytes (),
117- nullptr );
106+ // Load constant data.
107+ int segment_offset =
108+ flat_tensor_->segments ()->Get (metadata.get ()->segment_index ())->offset ();
109+ return loader_->load (
110+ header_.segment_base_offset + segment_offset + metadata.get ()->offset (),
111+ tensor_layout.get ().nbytes (),
112+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
118113}
119114
120115ET_NODISCARD Result<size_t > FlatTensorDataMap::load_data_into (
121116 ET_UNUSED const char * key,
122117 ET_UNUSED void * buffer,
123118 ET_UNUSED size_t size) const {
124- return Error::NotImplemented;
119+ // Get metadata to get nbytes.
120+ Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
121+ get_flat_tensor_metadata (key, flat_tensor_->tensors ());
122+ if (!metadata.ok ()) {
123+ return metadata.error ();
124+ }
125+ Result<const TensorLayout> tensor_layout =
126+ create_tensor_layout (metadata.get ());
127+ if (!tensor_layout.ok ()) {
128+ return tensor_layout.error ();
129+ }
130+ ET_CHECK_OR_RETURN_ERROR (
131+ size < tensor_layout.get ().nbytes (),
132+ InvalidArgument,
133+ " Buffer size %zu is smaller than tensor size %zu" ,
134+ size,
135+ tensor_layout.get ().nbytes ())
136+
137+ int segment_offset =
138+ flat_tensor_->segments ()->Get (metadata.get ()->segment_index ())->offset ();
139+ DataLoader::SegmentInfo info = DataLoader::SegmentInfo (
140+ DataLoader::SegmentInfo::Type::Mutable, 0 , nullptr );
141+
142+ return loader_->load_into (
143+ header_.segment_base_offset + segment_offset + metadata.get ()->offset (),
144+ tensor_layout.get ().nbytes (),
145+ info,
146+ buffer);
125147}
126148
127149ET_NODISCARD Result<size_t > FlatTensorDataMap::get_num_keys () const {
@@ -138,45 +160,34 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
138160
139161/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load (
140162 DataLoader* loader) {
141- // Load data map.
142- size_t flatbuffer_offset = 0 ;
143- size_t flatbuffer_size = 0 ;
144- size_t segment_base_offset = 0 ;
145- size_t segment_data_size = 0 ;
146- {
147- // Check header.
148- Result<FreeableBuffer> header = loader->load (
149- /* offset=*/ 0 ,
150- FlatTensorHeader::kNumHeadBytes ,
151- DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
152- if (!header.ok ()) {
153- return header.error ();
154- }
155- Result<FlatTensorHeader> fh =
156- FlatTensorHeader::Parse (header->data (), header->size ());
157- if (fh.ok ()) {
158- // The header has the data map size.
159- flatbuffer_offset = fh->flatbuffer_offset ;
160- flatbuffer_size = fh->flatbuffer_size ;
161- segment_base_offset = fh->segment_base_offset ;
162- segment_data_size = fh->segment_data_size ;
163- } else if (fh.error () == Error::NotFound) {
164- // No header, throw error.
165- ET_LOG (Error, " No FlatTensorHeader found." );
166- return fh.error ();
167- } else {
168- // corruption, throw error.
169- ET_LOG (Error, " Flat tensor header may be corrupt." );
170- return fh.error ();
171- }
163+ // Check header.
164+ Result<FreeableBuffer> header = loader->load (
165+ /* offset=*/ 0 ,
166+ FlatTensorHeader::kNumHeadBytes ,
167+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
168+ if (!header.ok ()) {
169+ ET_LOG (Error, " Failed to load header." );
170+ return header.error ();
171+ }
172+ Result<FlatTensorHeader> fh =
173+ FlatTensorHeader::Parse (header->data (), header->size ());
174+ if (fh.error () == Error::NotFound) {
175+ // No header, throw error.
176+ ET_LOG (Error, " No FlatTensorHeader found." );
177+ return fh.error ();
178+ } else if (fh.error () != Error::Ok) {
179+ // corruption, throw error.
180+ ET_LOG (Error, " Flat tensor header may be corrupt." );
181+ return fh.error ();
172182 }
173183
174184 // Load flatbuffer data as a segment.
175185 Result<FreeableBuffer> flat_tensor_data = loader->load (
176186 /* offset=*/ 0 ,
177- flatbuffer_offset + flatbuffer_size,
187+ fh-> flatbuffer_offset + fh-> flatbuffer_size ,
178188 DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
179189 if (!flat_tensor_data.ok ()) {
190+ ET_LOG (Error, " Failed to load flat_tensor data." );
180191 return flat_tensor_data.error ();
181192 }
182193
@@ -204,54 +215,8 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
204215 const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
205216 flat_tensor_flatbuffer::GetFlatTensor (flat_tensor_data->data ());
206217
207- // Validate flatbuffer data.
208- flatbuffers::Verifier verifier (
209- reinterpret_cast <const uint8_t *>(flat_tensor_data->data ()),
210- flat_tensor_data->size ());
211- bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer (verifier);
212- ET_CHECK_OR_RETURN_ERROR (
213- ok,
214- InvalidExternalData,
215- " Verification failed; data may be truncated or corrupt" );
216-
217- // Get pointer to tensor metadata.
218- const auto * s_tensor_metadata = flat_tensor->tensors ();
219- if (s_tensor_metadata == nullptr ) {
220- ET_LOG (Error, " FlatTensor has no tensor metadata." );
221- return Error::InvalidExternalData;
222- }
223-
224- // Load constant data.
225- const auto * s_data_segment = flat_tensor->segments ();
226-
227- // TODO(T214294528): Support multiple segments in FlatTensor.
228- if (s_data_segment->size () != 1 ) {
229- ET_LOG (
230- Error,
231- " FlatTensor has %u segments, only 1 supported." ,
232- s_data_segment->size ());
233- }
234- // First segment size should be <= the total segment data size.
235- int segment_size = s_data_segment->Get (0 )->size ();
236- int segment_offset = s_data_segment->Get (0 )->offset ();
237- if (segment_size > segment_data_size) {
238- ET_LOG (
239- Error,
240- " FlatTensor segment size %d > segment data size %zu" ,
241- segment_size,
242- segment_data_size);
243- }
244-
245- Result<FreeableBuffer> data_ro = loader->load (
246- /* offset=*/ segment_base_offset + segment_offset,
247- segment_size,
248- DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
249- if (!data_ro.ok ()) {
250- return data_ro.error ();
251- }
252-
253218 return FlatTensorDataMap (
254- std::move (flat_tensor_data .get ()), flat_tensor, std::move (data_ro .get ()));
219+ fh .get (), std::move (flat_tensor_data .get ()), flat_tensor, loader );
255220}
256221
257222} // namespace extension
0 commit comments