Skip to content

Commit c5cfe62

Browse files
committed
[ExecuTorch][Weight Sharing][XNNPACK] load named data map data for xnnpack
If data is serialized into the NamedDataMap, then we overload getConstantDataPtr to retrieve the data from the named data map. This is done in a backwards-compatible way: if no data is serialized into the named data map, then we still load the data from the flatbuffer payload. Since the runtime change here is being made before the AoT changes, all CI on this diff by itself should verify that the changes made here are backwards compatible. Note: we do not address runtime memory usage at this point; WeightCache will be implemented in the next diff. This means that if we load via the same key across different methods, we still pack twice and allocate two instances for the packed weights. Differential Revision: [D70315209](https://our.internmc.facebook.com/intern/diff/D70315209/) [ghstack-poisoned]
1 parent 23a3efd commit c5cfe62

File tree

5 files changed

+55
-12
lines changed

5 files changed

+55
-12
lines changed

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include <executorch/backends/xnnpack/runtime/XNNHeader.h>
1111
#include <executorch/backends/xnnpack/serialization/schema_generated.h>
1212
#include <executorch/extension/threadpool/threadpool.h>
13-
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
13+
#include <executorch/runtime/executor/pte_data_map.h>
1414
#include <unordered_map>
1515

1616
#pragma clang diagnostic ignored "-Wmissing-prototypes"
@@ -24,6 +24,8 @@ namespace delegate {
2424
using executorch::runtime::Error;
2525
using executorch::runtime::MemoryAllocator;
2626
using executorch::runtime::Result;
27+
using executorch::runtime::FreeableBuffer;
28+
using executorch::runtime::NamedDataMap;
2729

2830
/*
2931
* Provide compile-time allocation.
@@ -48,6 +50,7 @@ class CompileAllocator {
4850
using ValuePtr = const fb_xnnpack::XValue*;
4951
using NodePtr = const fb_xnnpack::XNode*;
5052
using GraphPtr = const fb_xnnpack::XNNGraph*;
53+
using ConstantDataOffsetPtr = const fb_xnnpack::ConstantDataOffset*;
5154
using DataType = fb_xnnpack::XNNDatatype;
5255

5356
// Type for define node function. This is the function signature
@@ -162,7 +165,9 @@ data associated with the tensor value, then returns nullptr.
162165
const uint8_t* getConstantDataPtr(
163166
const fb_xnnpack::XNNTensorValue* tensor_value,
164167
GraphPtr flatbuffer_graph,
165-
const uint8_t* constant_data_ptr) {
168+
const uint8_t* constant_data_ptr,
169+
const NamedDataMap* named_data_map,
170+
std::vector<FreeableBuffer>& loaded_buffers_from_map) {
166171
auto buffer_idx = tensor_value->constant_buffer_idx();
167172
if (buffer_idx) {
168173
if (!constant_data_ptr) {
@@ -171,10 +176,23 @@ const uint8_t* getConstantDataPtr(
171176
const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
172177
return constant_buffer[buffer_idx]->storage()->data();
173178
} else {
174-
const auto& constant_data_offsets = *flatbuffer_graph->constant_data();
175-
uint64_t constant_data_offset =
176-
constant_data_offsets[buffer_idx]->offset();
177-
return constant_data_ptr + constant_data_offset;
179+
ConstantDataOffsetPtr constant_data_offset = flatbuffer_graph->constant_data()->Get(buffer_idx);
180+
uint64_t offset = constant_data_offset->offset();
181+
182+
const std::string &data_name = constant_data_offset->named_key()->str();
183+
// If there is no tensor name
184+
if (data_name.length() == 0) {
185+
return constant_data_ptr + offset;
186+
} else {
187+
Result<FreeableBuffer> buffer = named_data_map->get_data(data_name.c_str());
188+
if (!buffer.ok()) {
189+
ET_LOG(Error, "Failed to get constant data for key %s", data_name.c_str());
190+
return nullptr;
191+
}
192+
const uint8_t* data_ptr = static_cast<const uint8_t*>(buffer.get().data());
193+
loaded_buffers_from_map.push_back(std::move(buffer.get()));
194+
return data_ptr;
195+
}
178196
}
179197
}
180198

@@ -194,7 +212,9 @@ Error defineTensor(
194212
const uint8_t* constant_data_ptr,
195213
std::vector<uint32_t>& input_ids,
196214
std::vector<uint32_t>& output_ids,
197-
CompileAllocator& allocator) {
215+
CompileAllocator& allocator,
216+
const NamedDataMap* named_data_map,
217+
std::vector<FreeableBuffer>& loaded_buffers_from_map) {
198218
const fb_xnnpack::XNNTensorValue* tensor_value = nullptr;
199219
const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr;
200220

@@ -231,8 +251,13 @@ Error defineTensor(
231251

232252
// Get Pointer to constant data from flatbuffer, if its non-constant
233253
// it is a nullptr
234-
const uint8_t* buffer_ptr =
235-
getConstantDataPtr(tensor_value, flatbuffer_graph, constant_data_ptr);
254+
const uint8_t* buffer_ptr = getConstantDataPtr(
255+
tensor_value,
256+
flatbuffer_graph,
257+
constant_data_ptr,
258+
named_data_map,
259+
loaded_buffers_from_map
260+
);
236261

237262
xnn_status status;
238263
// The type we might have to convert to
@@ -1968,6 +1993,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
19681993
size_t num_bytes,
19691994
XNNExecutor* executor,
19701995
MemoryAllocator* runtime_allocator,
1996+
const NamedDataMap* named_data_map,
19711997
xnn_workspace_t workspace) {
19721998
Result<XNNHeader> header = XNNHeader::Parse(buffer_pointer, num_bytes);
19731999
const uint8_t* flatbuffer_data = nullptr;
@@ -2036,6 +2062,7 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20362062
std::vector<uint32_t> input_ids;
20372063
std::vector<uint32_t> output_ids;
20382064
Error err = Error::Ok;
2065+
std::vector<FreeableBuffer> loaded_buffers_from_map;
20392066
for (auto value : *flatbuffer_graph->xvalues()) {
20402067
err = defineTensor(
20412068
subgraph.get(),
@@ -2045,7 +2072,9 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20452072
constant_data,
20462073
input_ids,
20472074
output_ids,
2048-
compile_allocator);
2075+
compile_allocator,
2076+
named_data_map,
2077+
loaded_buffers_from_map);
20492078

20502079
if (err != Error::Ok) {
20512080
return err;

backends/xnnpack/runtime/XNNCompiler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class XNNCompiler {
3030
size_t num_bytes,
3131
XNNExecutor* executor,
3232
executorch::runtime::MemoryAllocator* runtime_allocator,
33+
const executorch::runtime::NamedDataMap* named_data_map,
3334
xnn_workspace_t workspace);
3435
};
3536

backends/xnnpack/runtime/XNNPACKBackend.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include <executorch/runtime/backend/interface.h>
1111
#include <executorch/runtime/core/error.h>
1212
#include <executorch/runtime/core/evalue.h>
13-
#include <executorch/runtime/platform/profiler.h>
13+
#include <executorch/runtime/executor/pte_data_map.h>
1414

1515
#include <memory>
1616
#include <mutex>
@@ -30,6 +30,7 @@ using executorch::runtime::Error;
3030
using executorch::runtime::EValue;
3131
using executorch::runtime::FreeableBuffer;
3232
using executorch::runtime::Result;
33+
using executorch::runtime::NamedDataMap;
3334

3435
class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
3536
public:
@@ -79,13 +80,14 @@ class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
7980
return Error::MemoryAllocationFailed;
8081
}
8182

83+
const NamedDataMap* named_data_map = context.get_named_data_map();
84+
8285
#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
8386
// This is needed to serialize access to xnn_create_runtime which is not
8487
// thread safe. This can happen when multiple threads call init() on
8588
// the same backend instance.
8689
const std::lock_guard<std::mutex> lock(workspace_mutex_);
8790
#endif
88-
8991
// Executor has been allocated but not constructed, ensure that runtime_ is
9092
// nullptr by constructing it in place here. NOTE: Since we use placement
9193
// new and since this type is not trivially destructible, we must call the
@@ -96,6 +98,7 @@ class XnnpackBackend final : public ::executorch::runtime::BackendInterface {
9698
processed->size(),
9799
executor,
98100
context.get_runtime_allocator(),
101+
named_data_map,
99102
workspace_.get());
100103
// This backend does not need its processed data after compiling the model.
101104
processed->Free();

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,11 +320,20 @@ table XNNLeakyReLU {
320320
table ConstantDataOffset {
321321
// Constant data offsets are relative to the constant data base offset provided
322322
// in the XNNPACKHeader.
323+
// named_key and offset are mutually exclusive, meaning only one of these values
324+
// is valid. If the named key is a non-empty string, then the offset must be UINT64_MAX.
325+
// If the offset is not UINT64_MAX, then the named key must be an empty string
323326
offset: uint64;
324327

325328
// The size in bytes of valid data starting at the offset. The constant data
326329
// may be followed by padding before the next piece of constant data
327330
size: uint64;
331+
332+
// unique string id used to query the offset from the named data store.
333+
// named_key and offset are mutually exclusive, meaning only one of these values
334+
// is valid. If the named key is a non-empty string, then the offset must be UINT64_MAX.
335+
// If the offset is not UINT64_MAX, then the named key must be an empty string
336+
named_key: string;
328337
}
329338

330339
table XNNGraph {

backends/xnnpack/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def define_common_targets():
6060
"//executorch/backends/xnnpack/serialization:xnnpack_flatbuffer_header",
6161
"//executorch/extension/threadpool:threadpool",
6262
"//executorch/runtime/core/exec_aten/util:tensor_util",
63+
"//executorch/runtime/executor:pte_data_map"
6364
],
6465
# XnnpackBackend.cpp needs to compile with executor as whole
6566
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)

0 commit comments

Comments
 (0)