@@ -52,10 +52,6 @@ Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
5252  for  (int  i = 0 ; i < tensors->size (); i++) {
5353    if  (std::strcmp (tensors->Get (i)->fully_qualified_name ()->c_str (), key) ==
5454        0 ) {
55-       //  TODO(T214294528): Support multiple segments in FlatTensor.
56-       if  (tensors->Get (i)->segment_index () != 0 ) {
57-         return  Error::InvalidExternalData;
58-       }
5955      return  tensors->Get (i);
6056    }
6157  }
@@ -97,31 +93,68 @@ ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
9793    return  metadata_res.error ();
9894  }
9995  const  auto  metadata = metadata_res.get ();
100-   if  (metadata->segment_index () < 0  || metadata->offset () < 0 ) {
101-     //  Invalid segment_index/offset; malformed PTD file.
102-     return  Error::InvalidExternalData;
103-   }
96+   ET_CHECK_OR_RETURN_ERROR (
97+       metadata->segment_index () >= 0  && metadata->offset () >= 0 ,
98+       InvalidExternalData,
99+       " Invalid segment_index %d or offset %lu; malformed PTD file."  ,
100+       metadata->segment_index (),
101+       metadata->offset ())
104102
105-   Result<const  TensorLayout> tensor_layout_res  = create_tensor_layout (metadata);
106-   if  (!tensor_layout_res .ok ()) {
107-     return  tensor_layout_res .error ();
103+   Result<const  TensorLayout> tensor_layout  = create_tensor_layout (metadata);
104+   if  (!tensor_layout .ok ()) {
105+     return  tensor_layout .error ();
108106  }
109107
110-   //  This FreeableBuffer doesn't own the underlying data, and will not free it,
111-   //  which is why the free function is a nullptr.
112-   //  TODO(T214294528): Remove data_ro_ and instead load the data here, letting
113-   //  FreeableBuffer own it.
114-   return  FreeableBuffer (
115-       static_cast <const  uint8_t *>(data_ro_.data ()) + metadata->offset (),
116-       tensor_layout_res.get ().nbytes (),
117-       nullptr );
108+   //  Load constant data.
109+   const  auto * s_data_segment = flat_tensor_->segments ();
110+   int  segment_offset = s_data_segment->Get (0 )->offset ();
111+   return  loader_->load (
112+       header_.segment_base_offset  + segment_offset + metadata->offset (),
113+       tensor_layout.get ().nbytes (),
114+       DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
118115}
119116
120117ET_NODISCARD Result<size_t > FlatTensorDataMap::load_data_into (
121118    ET_UNUSED const  char * key,
122119    ET_UNUSED void * buffer,
123120    ET_UNUSED size_t  size) const  {
124-   return  Error::NotImplemented;
121+   auto  tensor_metadata = flat_tensor_->tensors ();
122+ 
123+   //  Get metadata to get nbytes.
124+   Result<const  flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
125+       get_flat_tensor_metadata (key, tensor_metadata);
126+   if  (!metadata_res.ok ()) {
127+     return  metadata_res.error ();
128+   }
129+   const  auto  metadata = metadata_res.get ();
130+   ET_CHECK_OR_RETURN_ERROR (
131+       metadata->segment_index () >= 0  && metadata->offset () >= 0 ,
132+       InvalidExternalData,
133+       " Invalid segment_index %d or offset %lu; malformed PTD file."  ,
134+       metadata->segment_index (),
135+       metadata->offset ())
136+ 
137+   Result<const  TensorLayout> tensor_layout = create_tensor_layout (metadata);
138+   if  (!tensor_layout.ok ()) {
139+     return  tensor_layout.error ();
140+   }
141+   ET_CHECK_OR_RETURN_ERROR (
142+       size < tensor_layout.get ().nbytes (),
143+       InvalidArgument,
144+       " Buffer size %zu is smaller than tensor size %zu"  ,
145+       size,
146+       tensor_layout.get ().nbytes ())
147+ 
148+   const  auto * s_data_segment = flat_tensor_->segments ();
149+   int  segment_offset = s_data_segment->Get (0 )->offset ();
150+   DataLoader::SegmentInfo info = DataLoader::SegmentInfo (
151+       DataLoader::SegmentInfo::Type::Mutable, 0 , nullptr );
152+ 
153+   return  loader_->load_into (
154+       header_.segment_base_offset  + segment_offset + metadata->offset (),
155+       tensor_layout.get ().nbytes (),
156+       info,
157+       buffer);
125158}
126159
127160ET_NODISCARD Result<size_t > FlatTensorDataMap::get_num_keys () const  {
@@ -138,45 +171,34 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
138171
139172/*  static */   Result<FlatTensorDataMap> FlatTensorDataMap::load (
140173    DataLoader* loader) {
141-   //  Load data map.
142-   size_t  flatbuffer_offset = 0 ;
143-   size_t  flatbuffer_size = 0 ;
144-   size_t  segment_base_offset = 0 ;
145-   size_t  segment_data_size = 0 ;
146-   {
147-     //  Check header.
148-     Result<FreeableBuffer> header = loader->load (
149-         /* offset=*/ 0 ,
150-         FlatTensorHeader::kNumHeadBytes ,
151-         DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
152-     if  (!header.ok ()) {
153-       return  header.error ();
154-     }
155-     Result<FlatTensorHeader> fh =
156-         FlatTensorHeader::Parse (header->data (), header->size ());
157-     if  (fh.ok ()) {
158-       //  The header has the data map size.
159-       flatbuffer_offset = fh->flatbuffer_offset ;
160-       flatbuffer_size = fh->flatbuffer_size ;
161-       segment_base_offset = fh->segment_base_offset ;
162-       segment_data_size = fh->segment_data_size ;
163-     } else  if  (fh.error () == Error::NotFound) {
164-       //  No header, throw error.
165-       ET_LOG (Error, " No FlatTensorHeader found."  );
166-       return  fh.error ();
167-     } else  {
168-       //  corruption, throw error.
169-       ET_LOG (Error, " Flat tensor header may be corrupt."  );
170-       return  fh.error ();
171-     }
174+   //  Check header.
175+   Result<FreeableBuffer> header = loader->load (
176+       /* offset=*/ 0 ,
177+       FlatTensorHeader::kNumHeadBytes ,
178+       DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
179+   if  (!header.ok ()) {
180+     ET_LOG (Error, " Failed to load header."  );
181+     return  header.error ();
182+   }
183+   Result<FlatTensorHeader> fh =
184+       FlatTensorHeader::Parse (header->data (), header->size ());
185+   if  (fh.error () == Error::NotFound) {
186+     //  No header, throw error.
187+     ET_LOG (Error, " No FlatTensorHeader found."  );
188+     return  fh.error ();
189+   } else  if  (fh.error () != Error::Ok) {
190+     //  corruption, throw error.
191+     ET_LOG (Error, " Flat tensor header may be corrupt."  );
192+     return  fh.error ();
172193  }
173194
174195  //  Load flatbuffer data as a segment.
175196  Result<FreeableBuffer> flat_tensor_data = loader->load (
176197      /* offset=*/ 0 ,
177-       flatbuffer_offset + flatbuffer_size,
198+       fh-> flatbuffer_offset  + fh-> flatbuffer_size ,
178199      DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
179200  if  (!flat_tensor_data.ok ()) {
201+     ET_LOG (Error, " Failed to load flat_tensor data."  );
180202    return  flat_tensor_data.error ();
181203  }
182204
@@ -204,54 +226,8 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
204226  const  flat_tensor_flatbuffer::FlatTensor* flat_tensor =
205227      flat_tensor_flatbuffer::GetFlatTensor (flat_tensor_data->data ());
206228
207-   //  Validate flatbuffer data.
208-   flatbuffers::Verifier verifier (
209-       reinterpret_cast <const  uint8_t *>(flat_tensor_data->data ()),
210-       flat_tensor_data->size ());
211-   bool  ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer (verifier);
212-   ET_CHECK_OR_RETURN_ERROR (
213-       ok,
214-       InvalidExternalData,
215-       " Verification failed; data may be truncated or corrupt"  );
216- 
217-   //  Get pointer to tensor metadata.
218-   const  auto * s_tensor_metadata = flat_tensor->tensors ();
219-   if  (s_tensor_metadata == nullptr ) {
220-     ET_LOG (Error, " FlatTensor has no tensor metadata."  );
221-     return  Error::InvalidExternalData;
222-   }
223- 
224-   //  Load constant data.
225-   const  auto * s_data_segment = flat_tensor->segments ();
226- 
227-   //  TODO(T214294528): Support multiple segments in FlatTensor.
228-   if  (s_data_segment->size () != 1 ) {
229-     ET_LOG (
230-         Error,
231-         " FlatTensor has %u segments, only 1 supported."  ,
232-         s_data_segment->size ());
233-   }
234-   //  First segment size should be <= the total segment data size.
235-   int  segment_size = s_data_segment->Get (0 )->size ();
236-   int  segment_offset = s_data_segment->Get (0 )->offset ();
237-   if  (segment_size > segment_data_size) {
238-     ET_LOG (
239-         Error,
240-         " FlatTensor segment size %d > segment data size %zu"  ,
241-         segment_size,
242-         segment_data_size);
243-   }
244- 
245-   Result<FreeableBuffer> data_ro = loader->load (
246-       /* offset=*/  segment_base_offset + segment_offset,
247-       segment_size,
248-       DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
249-   if  (!data_ro.ok ()) {
250-     return  data_ro.error ();
251-   }
252- 
253229  return  FlatTensorDataMap (
254-       std::move (flat_tensor_data .get ()), flat_tensor,  std::move (data_ro .get ()));
230+       fh .get (),  std::move (flat_tensor_data .get ()), flat_tensor, loader );
255231}
256232
257233} //  namespace extension
0 commit comments