1111#include < executorch/backends/xnnpack/serialization/schema_generated.h>
1212#include < executorch/extension/threadpool/threadpool.h>
1313#include < executorch/runtime/executor/pte_data_map.h>
14+ #include < string>
1415#include < unordered_map>
16+ #include < vector>
1517
1618#pragma clang diagnostic ignored "-Wmissing-prototypes"
1719#pragma clang diagnostic ignored "-Wglobal-constructors"
@@ -167,7 +169,8 @@ const uint8_t* getConstantDataPtr(
167169 GraphPtr flatbuffer_graph,
168170 const uint8_t * constant_data_ptr,
169171 const NamedDataMap* named_data_map,
170- std::vector<FreeableBuffer>& loaded_buffers_from_map) {
172+ std::vector<FreeableBuffer>& freeable_buffers,
173+ XNNWeightsCache* weights_cache) {
171174 auto buffer_idx = tensor_value->constant_buffer_idx ();
172175 if (buffer_idx) {
173176 if (!constant_data_ptr) {
@@ -187,6 +190,15 @@ const uint8_t* getConstantDataPtr(
187190 return constant_data_ptr + offset;
188191 } else {
189192 const std::string& data_name = constant_data_offset->named_key ()->str ();
193+ #ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
194+ Result<const uint8_t *> data_ptr =
195+ weights_cache->load_unpacked_data (data_name);
196+ if (!data_ptr.ok ()) {
197+ ET_LOG (Error, " Failed to load weights from cache" );
198+ return nullptr ;
199+ }
200+ return data_ptr.get ();
201+ #else
190202 Result<FreeableBuffer> buffer =
191203 named_data_map->get_data (data_name.c_str ());
192204 if (!buffer.ok ()) {
@@ -198,8 +210,9 @@ const uint8_t* getConstantDataPtr(
198210 }
199211 const uint8_t * data_ptr =
200212 static_cast <const uint8_t *>(buffer.get ().data ());
201- loaded_buffers_from_map .push_back (std::move (buffer.get ()));
213+ freeable_buffers .push_back (std::move (buffer.get ()));
202214 return data_ptr;
215+ #endif
203216 }
204217 }
205218 }
@@ -222,7 +235,8 @@ Error defineTensor(
222235 std::vector<uint32_t >& output_ids,
223236 CompileAllocator& allocator,
224237 const NamedDataMap* named_data_map,
225- std::vector<FreeableBuffer>& loaded_buffers_from_map) {
238+ std::vector<FreeableBuffer>& freeable_buffers,
239+ XNNWeightsCache* weights_cache) {
226240 const fb_xnnpack::XNNTensorValue* tensor_value = nullptr ;
227241 const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr ;
228242
@@ -264,7 +278,8 @@ Error defineTensor(
264278 flatbuffer_graph,
265279 constant_data_ptr,
266280 named_data_map,
267- loaded_buffers_from_map);
281+ freeable_buffers,
282+ weights_cache);
268283
269284 xnn_status status;
270285 // The type we might have to convert to
@@ -1999,9 +2014,9 @@ ET_NODISCARD Error XNNCompiler::compileModel(
19992014 const void * buffer_pointer,
20002015 size_t num_bytes,
20012016 XNNExecutor* executor,
2002- MemoryAllocator* runtime_allocator ,
2003- const NamedDataMap* named_data_map ,
2004- xnn_workspace_t workspace ) {
2017+ XNNWeightsCache* weights_cache ,
2018+ xnn_workspace_t workspace ,
2019+ const NamedDataMap* named_data_map ) {
20052020 Result<XNNHeader> header = XNNHeader::Parse (buffer_pointer, num_bytes);
20062021 const uint8_t * flatbuffer_data = nullptr ;
20072022 const uint8_t * constant_data = nullptr ;
@@ -2065,11 +2080,14 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20652080 // Invalid ids do not need to be remapped
20662081 remapped_ids.emplace (XNN_INVALID_VALUE_ID, XNN_INVALID_VALUE_ID);
20672082
2083+ // If weight cache is not on we hold onto all the unpacked buffers
2084+ // and we free them at the end
2085+ std::vector<FreeableBuffer> unpacked_buffers;
2086+
20682087 // External Ids for inputs and outputs
20692088 std::vector<uint32_t > input_ids;
20702089 std::vector<uint32_t > output_ids;
20712090 Error err = Error::Ok;
2072- std::vector<FreeableBuffer> loaded_buffers_from_map;
20732091 for (auto value : *flatbuffer_graph->xvalues ()) {
20742092 err = defineTensor (
20752093 subgraph.get (),
@@ -2081,7 +2099,8 @@ ET_NODISCARD Error XNNCompiler::compileModel(
20812099 output_ids,
20822100 compile_allocator,
20832101 named_data_map,
2084- loaded_buffers_from_map);
2102+ unpacked_buffers,
2103+ weights_cache);
20852104
20862105 if (err != Error::Ok) {
20872106 return err;
@@ -2103,20 +2122,34 @@ ET_NODISCARD Error XNNCompiler::compileModel(
21032122
21042123 xnn_runtime_t runtime_ptr = nullptr ;
21052124
2125+ // If the weights cache is not enabled, then XNNWeightsCache just manages
2126+ // the unpacked weights until the runtime is created.
2127+ #ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
2128+ ET_CHECK_OR_RETURN_ERROR (
2129+ unpacked_buffers.size () == 0 ,
2130+ Internal,
2131+ " Weight Cache is enabled, which means unpacked buffers should be owned by the cache" );
2132+ xnn_weights_cache_t weights_cache_ptr =
2133+ weights_cache->get_num_unpacked_data () > 0 ? weights_cache->get ()
2134+ : nullptr ;
2135+ #else
2136+ xnn_weights_cache_t weights_cache_ptr = nullptr ;
2137+ #endif
2138+
21062139#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
21072140 ET_CHECK_OR_RETURN_ERROR (
21082141 workspace != nullptr , Internal, " Failed to initialize XNNPACK workspace" );
21092142 status = xnn_create_runtime_v4 (
21102143 subgraph.get (),
2111- /* weight_cache= */ nullptr , // TODO - support weight cache
2144+ weights_cache_ptr,
21122145 workspace,
21132146 ::executorch::extension::threadpool::get_pthreadpool (),
21142147 runtime_flags,
21152148 &runtime_ptr);
21162149#else
21172150 status = xnn_create_runtime_v3 (
21182151 subgraph.get (),
2119- /* weight_cache= */ nullptr , // TODO - support weight cache
2152+ weights_cache_ptr,
21202153 ::executorch::extension::threadpool::get_pthreadpool (),
21212154 runtime_flags,
21222155 &runtime_ptr);
@@ -2128,10 +2161,25 @@ ET_NODISCARD Error XNNCompiler::compileModel(
21282161 " XNN Runtime creation failed with code: %s" ,
21292162 xnn_status_to_string (status));
21302163
2164+ #ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
2165+ auto packed_weights_names = weights_cache->finalize_for_runtime ();
2166+ ET_CHECK_OR_RETURN_ERROR (
2167+ packed_weights_names.ok (),
2168+ Internal,
2169+ " Failed to finalize weights cache after creating the xnn runtime" )
2170+ #else
2171+ for (auto & buffer : unpacked_buffers) {
2172+ buffer.Free ();
2173+ }
2174+ Result<std::vector<std::string>> packed_weights_names =
2175+ std::vector<std::string>();
2176+ #endif
2177+
21312178 err = executor->initialize ( // NOLINT: runtime_ptr is non-null
21322179 runtime_ptr,
21332180 std::move (input_ids),
2134- std::move (output_ids));
2181+ std::move (output_ids),
2182+ std::move (packed_weights_names.get ()));
21352183
21362184 return err;
21372185};
0 commit comments