@@ -52,11 +52,14 @@ Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
5252 for (int i = 0 ; i < tensors->size (); i++) {
5353 if (std::strcmp (tensors->Get (i)->fully_qualified_name ()->c_str (), key) ==
5454 0 ) {
55- // TODO(T214294528): Support multiple segments in FlatTensor.
56- if (tensors->Get (i)->segment_index () != 0 ) {
57- return Error::InvalidExternalData;
58- }
59- return tensors->Get (i);
55+ const auto * metadata = tensors->Get (i);
56+ ET_CHECK_OR_RETURN_ERROR (
57+ metadata->segment_index () >= 0 && metadata->offset () >= 0 ,
58+ InvalidExternalData,
59+ " Invalid segment_index %d or offset %lu; malformed PTD file." ,
60+ metadata->segment_index (),
61+ metadata->offset ());
62+ return metadata;
6063 }
6164 }
6265 return Error::NotFound;
@@ -75,6 +78,23 @@ Result<const TensorLayout> create_tensor_layout(
7578 scalar_type);
7679}
7780
81+ Result<int > get_and_check_segment_offset (
82+ const flatbuffers::Vector<
83+ flatbuffers::Offset<flat_tensor_flatbuffer::DataSegment>>* segments,
84+ const flat_tensor_flatbuffer::TensorMetadata* metadata) {
85+ ET_CHECK_OR_RETURN_ERROR (
86+ segments != nullptr ,
87+ InvalidExternalData,
88+ " No segments in external data flatbuffer." );
89+
90+ ET_CHECK_OR_RETURN_ERROR (
91+ metadata->segment_index () < segments->size (),
92+ InvalidExternalData,
93+ " Invalid segment_index %d; malformed PTD file." ,
94+ metadata->segment_index ());
95+ return segments->Get (metadata->segment_index ())->offset ();
96+ }
97+
7898} // namespace
7999
80100ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata (
@@ -89,39 +109,72 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
89109
90110ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data (
91111 const char * key) const {
92- auto tensor_metadata = flat_tensor_->tensors ();
93-
94- Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata_res =
95- get_flat_tensor_metadata (key, tensor_metadata);
96- if (!metadata_res.ok ()) {
97- return metadata_res.error ();
112+ Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
113+ get_flat_tensor_metadata (key, flat_tensor_->tensors ());
114+ if (!metadata.ok ()) {
115+ return metadata.error ();
98116 }
99- const auto metadata = metadata_res. get ();
100- if (metadata-> segment_index () < 0 || metadata-> offset () < 0 ) {
101- // Invalid segment_index/offset; malformed PTD file.
102- return Error::InvalidExternalData ;
117+ Result< const TensorLayout> tensor_layout =
118+ create_tensor_layout ( metadata. get ());
119+ if (!tensor_layout. ok ()) {
120+ return tensor_layout. error () ;
103121 }
104-
105- Result< const TensorLayout> tensor_layout_res = create_tensor_layout ( metadata);
106- if (!tensor_layout_res .ok ()) {
107- return tensor_layout_res .error ();
122+ Result< int > segment_offset =
123+ get_and_check_segment_offset (flat_tensor_-> segments (), metadata. get () );
124+ if (!segment_offset .ok ()) {
125+ return segment_offset .error ();
108126 }
109127
110- // This FreeableBuffer doesn't own the underlying data, and will not free it,
111- // which is why the free function is a nullptr.
112- // TODO(T214294528): Remove data_ro_ and instead load the data here, letting
113- // FreeableBuffer own it.
114- return FreeableBuffer (
115- static_cast <const uint8_t *>(data_ro_.data ()) + metadata->offset (),
116- tensor_layout_res.get ().nbytes (),
117- nullptr );
128+ // Load constant data.
129+ ET_CHECK_OR_RETURN_ERROR (
130+ segment_offset.get () <
131+ header_.segment_base_offset + header_.segment_data_size ,
132+ InvalidExternalData,
133+ " Invalid segment offset %d is larger than the segment_base_offset + segment_data_size %lu; malformed PTD file." ,
134+ segment_offset.get (),
135+ header_.segment_base_offset + header_.segment_data_size );
136+ return loader_->load (
137+ header_.segment_base_offset + segment_offset.get () +
138+ metadata.get ()->offset (),
139+ tensor_layout.get ().nbytes (),
140+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
118141}
119142
120143ET_NODISCARD Result<size_t > FlatTensorDataMap::load_data_into (
121144 ET_UNUSED const char * key,
122145 ET_UNUSED void * buffer,
123146 ET_UNUSED size_t size) const {
124- return Error::NotImplemented;
147+ Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
148+ get_flat_tensor_metadata (key, flat_tensor_->tensors ());
149+ if (!metadata.ok ()) {
150+ return metadata.error ();
151+ }
152+ Result<const TensorLayout> tensor_layout =
153+ create_tensor_layout (metadata.get ());
154+ if (!tensor_layout.ok ()) {
155+ return tensor_layout.error ();
156+ }
157+ ET_CHECK_OR_RETURN_ERROR (
158+ size < tensor_layout.get ().nbytes (),
159+ InvalidArgument,
160+ " Buffer size %zu is smaller than tensor size %zu" ,
161+ size,
162+ tensor_layout.get ().nbytes ());
163+
164+ Result<int > segment_offset =
165+ get_and_check_segment_offset (flat_tensor_->segments (), metadata.get ());
166+ if (!segment_offset.ok ()) {
167+ return segment_offset.error ();
168+ }
169+ // Load mutable data.
170+ DataLoader::SegmentInfo info = DataLoader::SegmentInfo (
171+ DataLoader::SegmentInfo::Type::Mutable, 0 , nullptr );
172+ return loader_->load_into (
173+ header_.segment_base_offset + segment_offset.get () +
174+ metadata.get ()->offset (),
175+ tensor_layout.get ().nbytes (),
176+ info,
177+ buffer);
125178}
126179
127180ET_NODISCARD Result<size_t > FlatTensorDataMap::get_num_keys () const {
@@ -138,45 +191,34 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
138191
139192/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load (
140193 DataLoader* loader) {
141- // Load data map.
142- size_t flatbuffer_offset = 0 ;
143- size_t flatbuffer_size = 0 ;
144- size_t segment_base_offset = 0 ;
145- size_t segment_data_size = 0 ;
146- {
147- // Check header.
148- Result<FreeableBuffer> header = loader->load (
149- /* offset=*/ 0 ,
150- FlatTensorHeader::kNumHeadBytes ,
151- DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
152- if (!header.ok ()) {
153- return header.error ();
154- }
155- Result<FlatTensorHeader> fh =
156- FlatTensorHeader::Parse (header->data (), header->size ());
157- if (fh.ok ()) {
158- // The header has the data map size.
159- flatbuffer_offset = fh->flatbuffer_offset ;
160- flatbuffer_size = fh->flatbuffer_size ;
161- segment_base_offset = fh->segment_base_offset ;
162- segment_data_size = fh->segment_data_size ;
163- } else if (fh.error () == Error::NotFound) {
164- // No header, throw error.
165- ET_LOG (Error, " No FlatTensorHeader found." );
166- return fh.error ();
167- } else {
168- // corruption, throw error.
169- ET_LOG (Error, " Flat tensor header may be corrupt." );
170- return fh.error ();
171- }
194+ // Check header.
195+ Result<FreeableBuffer> header = loader->load (
196+ /* offset=*/ 0 ,
197+ FlatTensorHeader::kNumHeadBytes ,
198+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
199+ if (!header.ok ()) {
200+ ET_LOG (Error, " Failed to load header." );
201+ return header.error ();
202+ }
203+ Result<FlatTensorHeader> fh =
204+ FlatTensorHeader::Parse (header->data (), header->size ());
205+ if (fh.error () == Error::NotFound) {
206+ // No header, throw error.
207+ ET_LOG (Error, " No FlatTensorHeader found." );
208+ return fh.error ();
209+ } else if (fh.error () != Error::Ok) {
210+ // corruption, throw error.
211+ ET_LOG (Error, " Flat tensor header may be corrupt." );
212+ return fh.error ();
172213 }
173214
174215 // Load flatbuffer data as a segment.
175216 Result<FreeableBuffer> flat_tensor_data = loader->load (
176217 /* offset=*/ 0 ,
177- flatbuffer_offset + flatbuffer_size,
218+ fh-> flatbuffer_offset + fh-> flatbuffer_size ,
178219 DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
179220 if (!flat_tensor_data.ok ()) {
221+ ET_LOG (Error, " Failed to load flat_tensor data." );
180222 return flat_tensor_data.error ();
181223 }
182224
@@ -204,54 +246,8 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
204246 const flat_tensor_flatbuffer::FlatTensor* flat_tensor =
205247 flat_tensor_flatbuffer::GetFlatTensor (flat_tensor_data->data ());
206248
207- // Validate flatbuffer data.
208- flatbuffers::Verifier verifier (
209- reinterpret_cast <const uint8_t *>(flat_tensor_data->data ()),
210- flat_tensor_data->size ());
211- bool ok = flat_tensor_flatbuffer::VerifyFlatTensorBuffer (verifier);
212- ET_CHECK_OR_RETURN_ERROR (
213- ok,
214- InvalidExternalData,
215- " Verification failed; data may be truncated or corrupt" );
216-
217- // Get pointer to tensor metadata.
218- const auto * s_tensor_metadata = flat_tensor->tensors ();
219- if (s_tensor_metadata == nullptr ) {
220- ET_LOG (Error, " FlatTensor has no tensor metadata." );
221- return Error::InvalidExternalData;
222- }
223-
224- // Load constant data.
225- const auto * s_data_segment = flat_tensor->segments ();
226-
227- // TODO(T214294528): Support multiple segments in FlatTensor.
228- if (s_data_segment->size () != 1 ) {
229- ET_LOG (
230- Error,
231- " FlatTensor has %u segments, only 1 supported." ,
232- s_data_segment->size ());
233- }
234- // First segment size should be <= the total segment data size.
235- int segment_size = s_data_segment->Get (0 )->size ();
236- int segment_offset = s_data_segment->Get (0 )->offset ();
237- if (segment_size > segment_data_size) {
238- ET_LOG (
239- Error,
240- " FlatTensor segment size %d > segment data size %zu" ,
241- segment_size,
242- segment_data_size);
243- }
244-
245- Result<FreeableBuffer> data_ro = loader->load (
246- /* offset=*/ segment_base_offset + segment_offset,
247- segment_size,
248- DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
249- if (!data_ro.ok ()) {
250- return data_ro.error ();
251- }
252-
253249 return FlatTensorDataMap (
254- std::move (flat_tensor_data .get ()), flat_tensor, std::move (data_ro .get ()));
250+ fh .get (), std::move (flat_tensor_data .get ()), flat_tensor, loader );
255251}
256252
257253} // namespace extension
0 commit comments