1010#include < executorch/backends/xnnpack/runtime/XNNHeader.h>
1111#include < executorch/backends/xnnpack/serialization/schema_generated.h>
1212#include < executorch/extension/threadpool/threadpool.h>
13- #include < executorch/runtime/core/exec_aten/util/scalar_type_util .h>
13+ #include < executorch/runtime/executor/pte_data_map .h>
1414#include < unordered_map>
1515
1616#pragma clang diagnostic ignored "-Wmissing-prototypes"
@@ -22,7 +22,9 @@ namespace xnnpack {
2222namespace delegate {
2323
2424using executorch::runtime::Error;
25+ using executorch::runtime::FreeableBuffer;
2526using executorch::runtime::MemoryAllocator;
27+ using executorch::runtime::NamedDataMap;
2628using executorch::runtime::Result;
2729
2830/*
@@ -48,6 +50,7 @@ class CompileAllocator {
4850using ValuePtr = const fb_xnnpack::XValue*;
4951using NodePtr = const fb_xnnpack::XNode*;
5052using GraphPtr = const fb_xnnpack::XNNGraph*;
53+ using ConstantDataOffsetPtr = const fb_xnnpack::ConstantDataOffset*;
5154using DataType = fb_xnnpack::XNNDatatype;
5255
5356// Type for define node function. This is the function signature
@@ -162,7 +165,9 @@ data associated with the tensor value, then returns nullptr.
162165const uint8_t * getConstantDataPtr (
163166 const fb_xnnpack::XNNTensorValue* tensor_value,
164167 GraphPtr flatbuffer_graph,
165- const uint8_t * constant_data_ptr) {
168+ const uint8_t * constant_data_ptr,
169+ const NamedDataMap* named_data_map,
170+ std::vector<FreeableBuffer>& loaded_buffers_from_map) {
166171 auto buffer_idx = tensor_value->constant_buffer_idx ();
167172 if (buffer_idx) {
168173 if (!constant_data_ptr) {
@@ -171,10 +176,31 @@ const uint8_t* getConstantDataPtr(
171176 const auto & constant_buffer = *flatbuffer_graph->constant_buffer ();
172177 return constant_buffer[buffer_idx]->storage ()->data ();
173178 } else {
174- const auto & constant_data_offsets = *flatbuffer_graph->constant_data ();
175- uint64_t constant_data_offset =
176- constant_data_offsets[buffer_idx]->offset ();
177- return constant_data_ptr + constant_data_offset;
179+ ConstantDataOffsetPtr constant_data_offset =
180+ flatbuffer_graph->constant_data ()->Get (buffer_idx);
181+ uint64_t offset = constant_data_offset->offset ();
182+
183+ bool has_named_key = flatbuffers::IsFieldPresent (
184+ constant_data_offset, fb_xnnpack::ConstantDataOffset::VT_NAMED_KEY);
185+ // If there is no tensor name
186+ if (!has_named_key) {
187+ return constant_data_ptr + offset;
188+ } else {
189+ const std::string& data_name = constant_data_offset->named_key ()->str ();
190+ Result<FreeableBuffer> buffer =
191+ named_data_map->get_data (data_name.c_str ());
192+ if (!buffer.ok ()) {
193+ ET_LOG (
194+ Error,
195+ " Failed to get constant data for key %s" ,
196+ data_name.c_str ());
197+ return nullptr ;
198+ }
199+ const uint8_t * data_ptr =
200+ static_cast <const uint8_t *>(buffer.get ().data ());
201+ loaded_buffers_from_map.push_back (std::move (buffer.get ()));
202+ return data_ptr;
203+ }
178204 }
179205 }
180206
@@ -194,7 +220,9 @@ Error defineTensor(
194220 const uint8_t * constant_data_ptr,
195221 std::vector<uint32_t >& input_ids,
196222 std::vector<uint32_t >& output_ids,
197- CompileAllocator& allocator) {
223+ CompileAllocator& allocator,
224+ const NamedDataMap* named_data_map,
225+ std::vector<FreeableBuffer>& loaded_buffers_from_map) {
198226 const fb_xnnpack::XNNTensorValue* tensor_value = nullptr ;
199227 const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr ;
200228
@@ -231,8 +259,12 @@ Error defineTensor(
231259
232260 // Get Pointer to constant data from flatbuffer, if its non-constant
233261 // it is a nullptr
234- const uint8_t * buffer_ptr =
235- getConstantDataPtr (tensor_value, flatbuffer_graph, constant_data_ptr);
262+ const uint8_t * buffer_ptr = getConstantDataPtr (
263+ tensor_value,
264+ flatbuffer_graph,
265+ constant_data_ptr,
266+ named_data_map,
267+ loaded_buffers_from_map);
236268
237269 xnn_status status;
238270 // The type we might have to convert to
@@ -1968,6 +2000,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
19682000 size_t num_bytes,
19692001 XNNExecutor* executor,
19702002 MemoryAllocator* runtime_allocator,
2003+ const NamedDataMap* named_data_map,
19712004 xnn_workspace_t workspace) {
19722005 Result<XNNHeader> header = XNNHeader::Parse (buffer_pointer, num_bytes);
19732006 const uint8_t * flatbuffer_data = nullptr ;
@@ -2036,6 +2069,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20362069 std::vector<uint32_t > input_ids;
20372070 std::vector<uint32_t > output_ids;
20382071 Error err = Error::Ok;
2072+ std::vector<FreeableBuffer> loaded_buffers_from_map;
20392073 for (auto value : *flatbuffer_graph->xvalues ()) {
20402074 err = defineTensor (
20412075 subgraph.get (),
@@ -2045,7 +2079,9 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20452079 constant_data,
20462080 input_ids,
20472081 output_ids,
2048- compile_allocator);
2082+ compile_allocator,
2083+ named_data_map,
2084+ loaded_buffers_from_map);
20492085
20502086 if (err != Error::Ok) {
20512087 return err;
0 commit comments