@@ -44,6 +44,28 @@ bool is_aligned(const void* data) {
4444 return addr % kMinimumAlignment == 0 ;
4545}
4646
47+ Result<const flat_tensor_flatbuffer::NamedData*> get_named_data (
48+ const char * key,
49+ const flatbuffers::Vector<
50+ flatbuffers::Offset<flat_tensor_flatbuffer::NamedData>>* named_data) {
51+ // Linear search by name.
52+ if (named_data == nullptr ) {
53+ return Error::NotFound;
54+ }
55+ for (int i = 0 ; i < named_data->size (); i++) {
56+ if (std::strcmp (named_data->Get (i)->key ()->c_str (), key) == 0 ) {
57+ const auto * metadata = named_data->Get (i);
58+ ET_CHECK_OR_RETURN_ERROR (
59+ metadata->segment_index () >= 0 ,
60+ InvalidExternalData,
61+ " Invalid segment_index %d; malformed PTD file." ,
62+ metadata->segment_index ());
63+ return metadata;
64+ }
65+ }
66+ return Error::NotFound;
67+ }
68+
4769Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata (
4870 const char * key,
4971 const flatbuffers::Vector<
@@ -109,6 +131,39 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
109131
110132ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data (
111133 const char * key) const {
134+ // TODO(lfq): consolidate named_data and tensors.
135+ // Check named data.
136+ Result<const flat_tensor_flatbuffer::NamedData*> named_data =
137+ get_named_data (key, flat_tensor_->named_data ());
138+ if (named_data.ok ()) {
139+ size_t segment_index = named_data.get ()->segment_index ();
140+ ET_CHECK_OR_RETURN_ERROR (
141+ segment_index < flat_tensor_->segments ()->size (),
142+ InvalidExternalData,
143+ " Invalid segment_index %zu; malformed PTD file." ,
144+ segment_index);
145+
146+ size_t segment_offset =
147+ flat_tensor_->segments ()->Get (segment_index)->offset ();
148+ size_t segment_size = flat_tensor_->segments ()->Get (segment_index)->size ();
149+ ET_CHECK_OR_RETURN_ERROR (
150+ segment_offset <
151+ header_.segment_base_offset + header_.segment_data_size ,
152+ InvalidExternalData,
153+ " Invalid segment offset %zu is larger than the segment_base_offset + segment_data_size %" PRIu64
154+ " ; malformed PTD file." ,
155+ segment_offset,
156+ header_.segment_base_offset + header_.segment_data_size );
157+ return loader_->load (
158+ /* offset=*/ header_.segment_base_offset + segment_offset,
159+ segment_size,
160+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
161+ }
162+ if (named_data.error () != Error::NotFound) {
163+ return named_data.error ();
164+ }
165+
166+ // Check tensors, if named data is not found.
112167 Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
113168 get_flat_tensor_metadata (key, flat_tensor_->tensors ());
114169 if (!metadata.ok ()) {
@@ -179,16 +234,34 @@ ET_NODISCARD Error FlatTensorDataMap::load_data_into(
179234}
180235
181236ET_NODISCARD Result<size_t > FlatTensorDataMap::get_num_keys () const {
182- return flat_tensor_->tensors ()->size ();
237+ // TODO(lfq): consolidate named_data and tensors.
238+ if (flat_tensor_->named_data () == nullptr ) {
239+ return flat_tensor_->tensors ()->size ();
240+ }
241+ return flat_tensor_->named_data ()->size () + flat_tensor_->tensors ()->size ();
183242}
184243
185244ET_NODISCARD Result<const char *> FlatTensorDataMap::get_key (
186245 size_t index) const {
187- if (index < 0 || index >= flat_tensor_->tensors ()->size ()) {
188- return Error::InvalidArgument;
189- }
246+ // TODO(lfq): consolidate named_data and tensors.
247+ // For now, iterate over named_data and then flat_tensor.
248+ size_t num_keys = get_num_keys ().get ();
249+ ET_CHECK_OR_RETURN_ERROR (
250+ index >= 0 && index < num_keys,
251+ InvalidArgument,
252+ " Index %zu out of range of size %zu" ,
253+ index,
254+ num_keys);
190255
191- return flat_tensor_->tensors ()->Get (index)->fully_qualified_name ()->c_str ();
256+ if (flat_tensor_->named_data () != nullptr &&
257+ index < flat_tensor_->named_data ()->size ()) {
258+ return flat_tensor_->named_data ()->Get (index)->key ()->c_str ();
259+ } else {
260+ if (flat_tensor_->named_data () != nullptr ) {
261+ index = index - flat_tensor_->named_data ()->size ();
262+ }
263+ return flat_tensor_->tensors ()->Get (index)->fully_qualified_name ()->c_str ();
264+ }
192265}
193266
194267/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load (
0 commit comments