1010#include < executorch/backends/xnnpack/runtime/XNNHeader.h>
1111#include < executorch/backends/xnnpack/serialization/schema_generated.h>
1212#include < executorch/extension/threadpool/threadpool.h>
13- #include < executorch/runtime/core/exec_aten/util/scalar_type_util .h>
13+ #include < executorch/runtime/executor/pte_data_map .h>
1414#include < unordered_map>
1515
1616#pragma clang diagnostic ignored "-Wmissing-prototypes"
@@ -24,6 +24,8 @@ namespace delegate {
2424using executorch::runtime::Error;
2525using executorch::runtime::MemoryAllocator;
2626using executorch::runtime::Result;
27+ using executorch::runtime::FreeableBuffer;
28+ using executorch::runtime::NamedDataMap;
2729
2830/*
2931 * Provide compile-time allocation.
@@ -48,6 +50,7 @@ class CompileAllocator {
4850using ValuePtr = const fb_xnnpack::XValue*;
4951using NodePtr = const fb_xnnpack::XNode*;
5052using GraphPtr = const fb_xnnpack::XNNGraph*;
53+ using ConstantDataOffsetPtr = const fb_xnnpack::ConstantDataOffset*;
5154using DataType = fb_xnnpack::XNNDatatype;
5255
5356// Type for define node function. This is the function signature
@@ -162,7 +165,9 @@ data associated with the tensor value, then returns nullptr.
162165const uint8_t * getConstantDataPtr (
163166 const fb_xnnpack::XNNTensorValue* tensor_value,
164167 GraphPtr flatbuffer_graph,
165- const uint8_t * constant_data_ptr) {
168+ const uint8_t * constant_data_ptr,
169+ const NamedDataMap* named_data_map,
170+ std::vector<FreeableBuffer>& loaded_buffers_from_map) {
166171 auto buffer_idx = tensor_value->constant_buffer_idx ();
167172 if (buffer_idx) {
168173 if (!constant_data_ptr) {
@@ -171,10 +176,23 @@ const uint8_t* getConstantDataPtr(
171176 const auto & constant_buffer = *flatbuffer_graph->constant_buffer ();
172177 return constant_buffer[buffer_idx]->storage ()->data ();
173178 } else {
174- const auto & constant_data_offsets = *flatbuffer_graph->constant_data ();
175- uint64_t constant_data_offset =
176- constant_data_offsets[buffer_idx]->offset ();
177- return constant_data_ptr + constant_data_offset;
179+ ConstantDataOffsetPtr constant_data_offset = flatbuffer_graph->constant_data ()->Get (buffer_idx);
180+ uint64_t offset = constant_data_offset->offset ();
181+
182+ const std::string &data_name = constant_data_offset->named_key ()->str ();
183+ // If there is no tensor name
184+ if (data_name.length () == 0 ) {
185+ return constant_data_ptr + offset;
186+ } else {
187+ Result<FreeableBuffer> buffer = named_data_map->get_data (data_name.c_str ());
188+ if (!buffer.ok ()) {
189+ ET_LOG (Error, " Failed to get constant data for key %s" , data_name.c_str ());
190+ return nullptr ;
191+ }
192+ const uint8_t * data_ptr = static_cast <const uint8_t *>(buffer.get ().data ());
193+ loaded_buffers_from_map.push_back (std::move (buffer.get ()));
194+ return data_ptr;
195+ }
178196 }
179197 }
180198
@@ -194,7 +212,9 @@ Error defineTensor(
194212 const uint8_t * constant_data_ptr,
195213 std::vector<uint32_t >& input_ids,
196214 std::vector<uint32_t >& output_ids,
197- CompileAllocator& allocator) {
215+ CompileAllocator& allocator,
216+ const NamedDataMap* named_data_map,
217+ std::vector<FreeableBuffer>& loaded_buffers_from_map) {
198218 const fb_xnnpack::XNNTensorValue* tensor_value = nullptr ;
199219 const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr ;
200220
@@ -231,8 +251,13 @@ Error defineTensor(
231251
232252 // Get Pointer to constant data from flatbuffer, if its non-constant
233253 // it is a nullptr
234- const uint8_t * buffer_ptr =
235- getConstantDataPtr (tensor_value, flatbuffer_graph, constant_data_ptr);
254+ const uint8_t * buffer_ptr = getConstantDataPtr (
255+ tensor_value,
256+ flatbuffer_graph,
257+ constant_data_ptr,
258+ named_data_map,
259+ loaded_buffers_from_map
260+ );
236261
237262 xnn_status status;
238263 // The type we might have to convert to
@@ -1968,6 +1993,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
19681993 size_t num_bytes,
19691994 XNNExecutor* executor,
19701995 MemoryAllocator* runtime_allocator,
1996+ const NamedDataMap* named_data_map,
19711997 xnn_workspace_t workspace) {
19721998 Result<XNNHeader> header = XNNHeader::Parse (buffer_pointer, num_bytes);
19731999 const uint8_t * flatbuffer_data = nullptr ;
@@ -2036,6 +2062,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20362062 std::vector<uint32_t > input_ids;
20372063 std::vector<uint32_t > output_ids;
20382064 Error err = Error::Ok;
2065+ std::vector<FreeableBuffer> loaded_buffers_from_map;
20392066 for (auto value : *flatbuffer_graph->xvalues ()) {
20402067 err = defineTensor (
20412068 subgraph.get (),
@@ -2045,7 +2072,9 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20452072 constant_data,
20462073 input_ids,
20472074 output_ids,
2048- compile_allocator);
2075+ compile_allocator,
2076+ named_data_map,
2077+ loaded_buffers_from_map);
20492078
20502079 if (err != Error::Ok) {
20512080 return err;
0 commit comments