1010#include  < executorch/backends/xnnpack/runtime/XNNHeader.h> 
1111#include  < executorch/backends/xnnpack/serialization/schema_generated.h> 
1212#include  < executorch/extension/threadpool/threadpool.h> 
13- #include  < executorch/runtime/core/exec_aten/util/scalar_type_util .h> 
13+ #include  < executorch/runtime/executor/pte_data_map .h> 
1414#include  < unordered_map> 
1515
1616#pragma  clang diagnostic ignored "-Wmissing-prototypes"
@@ -22,7 +22,9 @@ namespace xnnpack {
2222namespace  delegate  {
2323
2424using  executorch::runtime::Error;
25+ using  executorch::runtime::FreeableBuffer;
2526using  executorch::runtime::MemoryAllocator;
27+ using  executorch::runtime::NamedDataMap;
2628using  executorch::runtime::Result;
2729
2830/* 
@@ -48,6 +50,7 @@ class CompileAllocator {
4850using  ValuePtr = const  fb_xnnpack::XValue*;
4951using  NodePtr = const  fb_xnnpack::XNode*;
5052using  GraphPtr = const  fb_xnnpack::XNNGraph*;
53+ using  ConstantDataOffsetPtr = const  fb_xnnpack::ConstantDataOffset*;
5154using  DataType = fb_xnnpack::XNNDatatype;
5255
5356//  Type for define node function. This is the function signature
@@ -162,7 +165,9 @@ data associated with the tensor value, then returns nullptr.
162165const  uint8_t * getConstantDataPtr (
163166    const  fb_xnnpack::XNNTensorValue* tensor_value,
164167    GraphPtr flatbuffer_graph,
165-     const  uint8_t * constant_data_ptr) {
168+     const  uint8_t * constant_data_ptr,
169+     const  NamedDataMap* named_data_map,
170+     std::vector<FreeableBuffer>& loaded_buffers_from_map) {
166171  auto  buffer_idx = tensor_value->constant_buffer_idx ();
167172  if  (buffer_idx) {
168173    if  (!constant_data_ptr) {
@@ -171,10 +176,31 @@ const uint8_t* getConstantDataPtr(
171176      const  auto & constant_buffer = *flatbuffer_graph->constant_buffer ();
172177      return  constant_buffer[buffer_idx]->storage ()->data ();
173178    } else  {
174-       const  auto & constant_data_offsets = *flatbuffer_graph->constant_data ();
175-       uint64_t  constant_data_offset =
176-           constant_data_offsets[buffer_idx]->offset ();
177-       return  constant_data_ptr + constant_data_offset;
179+       ConstantDataOffsetPtr constant_data_offset =
180+           flatbuffer_graph->constant_data ()->Get (buffer_idx);
181+       uint64_t  offset = constant_data_offset->offset ();
182+ 
183+       bool  has_named_key = flatbuffers::IsFieldPresent (
184+           constant_data_offset, fb_xnnpack::ConstantDataOffset::VT_NAMED_KEY);
185+       //  If there is no tensor name
186+       if  (!has_named_key) {
187+         return  constant_data_ptr + offset;
188+       } else  {
189+         const  std::string& data_name = constant_data_offset->named_key ()->str ();
190+         Result<FreeableBuffer> buffer =
191+             named_data_map->get_data (data_name.c_str ());
192+         if  (!buffer.ok ()) {
193+           ET_LOG (
194+               Error,
195+               " Failed to get constant data for key %s"  ,
196+               data_name.c_str ());
197+           return  nullptr ;
198+         }
199+         const  uint8_t * data_ptr =
200+             static_cast <const  uint8_t *>(buffer.get ().data ());
201+         loaded_buffers_from_map.push_back (std::move (buffer.get ()));
202+         return  data_ptr;
203+       }
178204    }
179205  }
180206
@@ -194,7 +220,9 @@ Error defineTensor(
194220    const  uint8_t * constant_data_ptr,
195221    std::vector<uint32_t >& input_ids,
196222    std::vector<uint32_t >& output_ids,
197-     CompileAllocator& allocator) {
223+     CompileAllocator& allocator,
224+     const  NamedDataMap* named_data_map,
225+     std::vector<FreeableBuffer>& loaded_buffers_from_map) {
198226  const  fb_xnnpack::XNNTensorValue* tensor_value = nullptr ;
199227  const  fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr ;
200228
@@ -231,8 +259,12 @@ Error defineTensor(
231259
232260  //  Get Pointer to constant data from flatbuffer, if its non-constant
233261  //  it is a nullptr
234-   const  uint8_t * buffer_ptr =
235-       getConstantDataPtr (tensor_value, flatbuffer_graph, constant_data_ptr);
262+   const  uint8_t * buffer_ptr = getConstantDataPtr (
263+       tensor_value,
264+       flatbuffer_graph,
265+       constant_data_ptr,
266+       named_data_map,
267+       loaded_buffers_from_map);
236268
237269  xnn_status status;
238270  //  The type we might have to convert to
@@ -1968,6 +2000,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
19682000    size_t  num_bytes,
19692001    XNNExecutor* executor,
19702002    MemoryAllocator* runtime_allocator,
2003+     const  NamedDataMap* named_data_map,
19712004    xnn_workspace_t  workspace) {
19722005  Result<XNNHeader> header = XNNHeader::Parse (buffer_pointer, num_bytes);
19732006  const  uint8_t * flatbuffer_data = nullptr ;
@@ -2036,6 +2069,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20362069  std::vector<uint32_t > input_ids;
20372070  std::vector<uint32_t > output_ids;
20382071  Error err = Error::Ok;
2072+   std::vector<FreeableBuffer> loaded_buffers_from_map;
20392073  for  (auto  value : *flatbuffer_graph->xvalues ()) {
20402074    err = defineTensor (
20412075        subgraph.get (),
@@ -2045,7 +2079,9 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20452079        constant_data,
20462080        input_ids,
20472081        output_ids,
2048-         compile_allocator);
2082+         compile_allocator,
2083+         named_data_map,
2084+         loaded_buffers_from_map);
20492085
20502086    if  (err != Error::Ok) {
20512087      return  err;
0 commit comments