@@ -52,11 +52,14 @@ Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
5252  for  (int  i = 0 ; i < tensors->size (); i++) {
5353    if  (std::strcmp (tensors->Get (i)->fully_qualified_name ()->c_str (), key) ==
5454        0 ) {
55-       //  TODO(T214294528): Support multiple segments in FlatTensor.
56-       if  (tensors->Get (i)->segment_index () != 0 ) {
57-         return  Error::InvalidExternalData;
58-       }
59-       return  tensors->Get (i);
55+       const  auto * metadata = tensors->Get (i);
56+       ET_CHECK_OR_RETURN_ERROR (
57+           metadata->segment_index () >= 0  && metadata->offset () >= 0 ,
58+           InvalidExternalData,
59+           " Invalid segment_index %d or offset %"   PRIu64 " ; malformed PTD file."  ,
60+           metadata->segment_index (),
61+           metadata->offset ());
62+       return  metadata;
6063    }
6164  }
6265  return  Error::NotFound;
@@ -75,6 +78,23 @@ Result<const TensorLayout> create_tensor_layout(
7578      scalar_type);
7679}
7780
81+ Result<int > get_and_check_segment_offset (
82+     const  flatbuffers::Vector<
83+         flatbuffers::Offset<flat_tensor_flatbuffer::DataSegment>>* segments,
84+     const  flat_tensor_flatbuffer::TensorMetadata* metadata) {
85+   ET_CHECK_OR_RETURN_ERROR (
86+       segments != nullptr ,
87+       InvalidExternalData,
88+       " No segments in external data flatbuffer."  );
89+ 
90+   ET_CHECK_OR_RETURN_ERROR (
91+       metadata->segment_index () < segments->size (),
92+       InvalidExternalData,
93+       " Invalid segment_index %d; malformed PTD file."  ,
94+       metadata->segment_index ());
95+   return  segments->Get (metadata->segment_index ())->offset ();
96+ }
97+ 
7898} //  namespace
7999
80100ET_NODISCARD Result<const  TensorLayout> FlatTensorDataMap::get_metadata (
@@ -89,39 +109,73 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
89109
90110ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data (
91111    const  char * key) const  {
92-   auto  tensor_metadata = flat_tensor_->tensors ();
93- 
94-   Result<const  flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
95-       get_flat_tensor_metadata (key, tensor_metadata);
96-   if  (!metadata_res.ok ()) {
97-     return  metadata_res.error ();
112+   Result<const  flat_tensor_flatbuffer::TensorMetadata*> metadata =
113+       get_flat_tensor_metadata (key, flat_tensor_->tensors ());
114+   if  (!metadata.ok ()) {
115+     return  metadata.error ();
98116  }
99-   const  auto  metadata = metadata_res. get (); 
100-   if  (metadata-> segment_index () <  0  ||  metadata-> offset () <  0 ) { 
101-      //  Invalid segment_index/offset; malformed PTD file. 
102-     return  Error::InvalidExternalData ;
117+   Result< const  TensorLayout> tensor_layout = 
118+        create_tensor_layout ( metadata. get ()); 
119+   if  (!tensor_layout. ok ()) { 
120+     return  tensor_layout. error () ;
103121  }
104- 
105-   Result< const  TensorLayout> tensor_layout_res =  create_tensor_layout ( metadata);
106-   if  (!tensor_layout_res .ok ()) {
107-     return  tensor_layout_res .error ();
122+   Result< int > segment_offset = 
123+        get_and_check_segment_offset (flat_tensor_-> segments (),  metadata. get () );
124+   if  (!segment_offset .ok ()) {
125+     return  segment_offset .error ();
108126  }
109127
110-   //  This FreeableBuffer doesn't own the underlying data, and will not free it,
111-   //  which is why the free function is a nullptr.
112-   //  TODO(T214294528): Remove data_ro_ and instead load the data here, letting
113-   //  FreeableBuffer own it.
114-   return  FreeableBuffer (
115-       static_cast <const  uint8_t *>(data_ro_.data ()) + metadata->offset (),
116-       tensor_layout_res.get ().nbytes (),
117-       nullptr );
128+   //  Load constant data.
129+   ET_CHECK_OR_RETURN_ERROR (
130+       segment_offset.get () <
131+           header_.segment_base_offset  + header_.segment_data_size ,
132+       InvalidExternalData,
133+       " Invalid segment offset %d is larger than the segment_base_offset + segment_data_size %"   PRIu64
134+       " ; malformed PTD file."  ,
135+       segment_offset.get (),
136+       header_.segment_base_offset  + header_.segment_data_size );
137+   return  loader_->load (
138+       header_.segment_base_offset  + segment_offset.get () +
139+           metadata.get ()->offset (),
140+       tensor_layout.get ().nbytes (),
141+       DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
118142}
119143
120144ET_NODISCARD Result<size_t > FlatTensorDataMap::load_data_into (
121145    ET_UNUSED const  char * key,
122146    ET_UNUSED void * buffer,
123147    ET_UNUSED size_t  size) const  {
124-   return  Error::NotImplemented;
148+   Result<const  flat_tensor_flatbuffer::TensorMetadata*> metadata =
149+       get_flat_tensor_metadata (key, flat_tensor_->tensors ());
150+   if  (!metadata.ok ()) {
151+     return  metadata.error ();
152+   }
153+   Result<const  TensorLayout> tensor_layout =
154+       create_tensor_layout (metadata.get ());
155+   if  (!tensor_layout.ok ()) {
156+     return  tensor_layout.error ();
157+   }
158+   ET_CHECK_OR_RETURN_ERROR (
159+       size < tensor_layout.get ().nbytes (),
160+       InvalidArgument,
161+       " Buffer size %zu is smaller than tensor size %zu"  ,
162+       size,
163+       tensor_layout.get ().nbytes ());
164+ 
165+   Result<int > segment_offset =
166+       get_and_check_segment_offset (flat_tensor_->segments (), metadata.get ());
167+   if  (!segment_offset.ok ()) {
168+     return  segment_offset.error ();
169+   }
170+   //  Load mutable data.
171+   DataLoader::SegmentInfo info = DataLoader::SegmentInfo (
172+       DataLoader::SegmentInfo::Type::Mutable, 0 , nullptr );
173+   return  loader_->load_into (
174+       header_.segment_base_offset  + segment_offset.get () +
175+           metadata.get ()->offset (),
176+       tensor_layout.get ().nbytes (),
177+       info,
178+       buffer);
125179}
126180
127181ET_NODISCARD Result<size_t > FlatTensorDataMap::get_num_keys () const  {
@@ -138,45 +192,34 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
138192
139193/*  static */   Result<FlatTensorDataMap> FlatTensorDataMap::load (
140194    DataLoader* loader) {
141-   //  Load data map.
142-   size_t  flatbuffer_offset = 0 ;
143-   size_t  flatbuffer_size = 0 ;
144-   size_t  segment_base_offset = 0 ;
145-   size_t  segment_data_size = 0 ;
146-   {
147-     //  Check header.
148-     Result<FreeableBuffer> header = loader->load (
149-         /* offset=*/ 0 ,
150-         FlatTensorHeader::kNumHeadBytes ,
151-         DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
152-     if  (!header.ok ()) {
153-       return  header.error ();
154-     }
155-     Result<FlatTensorHeader> fh =
156-         FlatTensorHeader::Parse (header->data (), header->size ());
157-     if  (fh.ok ()) {
158-       //  The header has the data map size.
159-       flatbuffer_offset = fh->flatbuffer_offset ;
160-       flatbuffer_size = fh->flatbuffer_size ;
161-       segment_base_offset = fh->segment_base_offset ;
162-       segment_data_size = fh->segment_data_size ;
163-     } else  if  (fh.error () == Error::NotFound) {
164-       //  No header, throw error.
165-       ET_LOG (Error, " No FlatTensorHeader found."  );
166-       return  fh.error ();
167-     } else  {
168-       //  corruption, throw error.
169-       ET_LOG (Error, " Flat tensor header may be corrupt."  );
170-       return  fh.error ();
171-     }
195+   //  Check header.
196+   Result<FreeableBuffer> header = loader->load (
197+       /* offset=*/ 0 ,
198+       FlatTensorHeader::kNumHeadBytes ,
199+       DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
200+   if  (!header.ok ()) {
201+     ET_LOG (Error, " Failed to load header."  );
202+     return  header.error ();
203+   }
204+   Result<FlatTensorHeader> fh =
205+       FlatTensorHeader::Parse (header->data (), header->size ());
206+   if  (fh.error () == Error::NotFound) {
207+     //  No header, throw error.
208+     ET_LOG (Error, " No FlatTensorHeader found."  );
209+     return  fh.error ();
210+   } else  if  (fh.error () != Error::Ok) {
211+     //  corruption, throw error.
212+     ET_LOG (Error, " Flat tensor header may be corrupt."  );
213+     return  fh.error ();
172214  }
173215
174216  //  Load flatbuffer data as a segment.
175217  Result<FreeableBuffer> flat_tensor_data = loader->load (
176218      /* offset=*/ 0 ,
177-       flatbuffer_offset + flatbuffer_size,
219+       fh-> flatbuffer_offset  + fh-> flatbuffer_size ,
178220      DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
179221  if  (!flat_tensor_data.ok ()) {
222+     ET_LOG (Error, " Failed to load flat_tensor data."  );
180223    return  flat_tensor_data.error ();
181224  }
182225
@@ -204,54 +247,8 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
204247  const  flat_tensor_flatbuffer::FlatTensor* flat_tensor =
205248      flat_tensor_flatbuffer::GetFlatTensor (flat_tensor_data->data ());
206249
207-   //  Validate flatbuffer data.
208-   flatbuffers::Verifier verifier (
209-       reinterpret_cast <const  uint8_t *>(flat_tensor_data->data ()),
210-       flat_tensor_data->size ());
211-   bool  ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer (verifier);
212-   ET_CHECK_OR_RETURN_ERROR (
213-       ok,
214-       InvalidExternalData,
215-       " Verification failed; data may be truncated or corrupt"  );
216- 
217-   //  Get pointer to tensor metadata.
218-   const  auto * s_tensor_metadata = flat_tensor->tensors ();
219-   if  (s_tensor_metadata == nullptr ) {
220-     ET_LOG (Error, " FlatTensor has no tensor metadata."  );
221-     return  Error::InvalidExternalData;
222-   }
223- 
224-   //  Load constant data.
225-   const  auto * s_data_segment = flat_tensor->segments ();
226- 
227-   //  TODO(T214294528): Support multiple segments in FlatTensor.
228-   if  (s_data_segment->size () != 1 ) {
229-     ET_LOG (
230-         Error,
231-         " FlatTensor has %u segments, only 1 supported."  ,
232-         s_data_segment->size ());
233-   }
234-   //  First segment size should be <= the total segment data size.
235-   int  segment_size = s_data_segment->Get (0 )->size ();
236-   int  segment_offset = s_data_segment->Get (0 )->offset ();
237-   if  (segment_size > segment_data_size) {
238-     ET_LOG (
239-         Error,
240-         " FlatTensor segment size %d > segment data size %zu"  ,
241-         segment_size,
242-         segment_data_size);
243-   }
244- 
245-   Result<FreeableBuffer> data_ro = loader->load (
246-       /* offset=*/  segment_base_offset + segment_offset,
247-       segment_size,
248-       DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
249-   if  (!data_ro.ok ()) {
250-     return  data_ro.error ();
251-   }
252- 
253250  return  FlatTensorDataMap (
254-       std::move (flat_tensor_data .get ()), flat_tensor,  std::move (data_ro .get ()));
251+       fh .get (),  std::move (flat_tensor_data .get ()), flat_tensor, loader );
255252}
256253
257254} //  namespace extension
0 commit comments