2222#include < executorch/runtime/core/event_tracer_hooks_delegate.h>
2323#endif // ET_EVENT_TRACER_ENABLED
2424#include < executorch/runtime/core/exec_aten/util/tensor_util.h>
25+ #include < executorch/runtime/executor/pte_data_map.h>
2526#include < executorch/runtime/platform/compiler.h>
2627#include < executorch/runtime/platform/profiler.h>
2728
@@ -47,6 +48,7 @@ using executorch::runtime::Error;
4748using executorch::runtime::EValue;
4849using executorch::runtime::FreeableBuffer;
4950using executorch::runtime::kTensorDimensionLimit ;
51+ using executorch::runtime::NamedDataMap;
5052using executorch::runtime::Result;
5153using executorch::runtime::Span;
5254
@@ -69,8 +71,29 @@ using UIntVector = const flatbuffers::Vector<uint32_t>*;
6971const uint8_t * get_constant_data_ptr (
7072 VkGraphPtr flatbuffer_graph,
7173 const int32_t buffer_idx,
72- const uint8_t * constant_data) {
74+ const uint8_t * constant_data,
75+ const NamedDataMap* named_data_map,
76+ std::vector<FreeableBuffer>& loaded_buffers_from_map) {
7377 VkBytesPtr constant_bytes = flatbuffer_graph->constants ()->Get (buffer_idx);
78+
79+ // Check if there's a named key for this constant data
80+ if (constant_bytes->named_key () != nullptr && named_data_map != nullptr ) {
81+ const std::string& data_name = constant_bytes->named_key ()->str ();
82+ Result<FreeableBuffer> buffer = named_data_map->get_data (data_name.c_str ());
83+ if (!buffer.ok ()) {
84+ ET_LOG (
85+ Error,
86+ " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
87+ data_name.c_str (),
88+ static_cast <uint32_t >(buffer.error ()));
89+ return nullptr ;
90+ }
91+ const uint8_t * data_ptr = static_cast <const uint8_t *>(buffer.get ().data ());
92+ loaded_buffers_from_map.push_back (std::move (buffer.get ()));
93+ return data_ptr;
94+ }
95+
96+ // Fallback to offset-based access
7497 return constant_data + constant_bytes->offset ();
7598}
7699
@@ -166,17 +189,22 @@ class GraphBuilder {
166189 ComputeGraph* compute_graph_;
167190 VkGraphPtr flatbuffer_;
168191 const uint8_t * constant_data_;
192+ const NamedDataMap* named_data_map_;
193+ std::vector<FreeableBuffer> loaded_buffers_from_map_;
169194
170195 std::vector<ValueRef> ref_mapping_;
171196
172197 public:
173198 explicit GraphBuilder (
174199 ComputeGraph* compute_graph,
175200 VkGraphPtr flatbuffer,
176- const uint8_t * constant_data)
201+ const uint8_t * constant_data,
202+ const NamedDataMap* named_data_map)
177203 : compute_graph_(compute_graph),
178204 flatbuffer_(flatbuffer),
179205 constant_data_(constant_data),
206+ named_data_map_(named_data_map),
207+ loaded_buffers_from_map_(),
180208 ref_mapping_() {}
181209
182210 void resize (uint32_t size) {
@@ -212,10 +240,27 @@ class GraphBuilder {
212240
213241 ValueRef ref;
214242 if (tensor_fb->constant_id () >= 0 ) {
215- const uint8_t * tensor_data = get_constant_data_ptr (
216- flatbuffer_, tensor_fb->constant_id (), constant_data_);
243+ VkBytesPtr constant_bytes =
244+ flatbuffer_->constants ()->Get (tensor_fb->constant_id ());
245+
246+ if (constant_bytes->named_key () != nullptr &&
247+ constant_bytes->offset () == UINT64_MAX &&
248+ named_data_map_ != nullptr ) {
249+ const std::string& data_name = constant_bytes->named_key ()->str ();
250+ Result<FreeableBuffer> buffer =
251+ named_data_map_->get_data (data_name.c_str ());
217252
218- ref = compute_graph_->add_tensorref (dims_vector, dtype, tensor_data);
253+ VK_CHECK_COND (
254+ buffer.ok (),
255+ " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
256+ data_name.c_str (),
257+ static_cast <uint32_t >(buffer.error ()));
258+ ref = compute_graph_->add_tensorref (
259+ dims_vector, dtype, std::move (buffer.get ()));
260+ } else {
261+ const uint8_t * tensor_data = constant_data_ + constant_bytes->offset ();
262+ ref = compute_graph_->add_tensorref (dims_vector, dtype, tensor_data);
263+ }
219264 } else {
220265 ref = compute_graph_->add_tensor (
221266 dims_vector,
@@ -479,8 +524,10 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
479524 return true ;
480525 }
481526
482- ET_NODISCARD Error
483- compileModel (const void * buffer_pointer, ComputeGraph* compute_graph) const {
527+ ET_NODISCARD Error compileModel (
528+ const void * buffer_pointer,
529+ ComputeGraph* compute_graph,
530+ const NamedDataMap* named_data_map) const {
484531 Result<VulkanDelegateHeader> header =
485532 VulkanDelegateHeader::parse (buffer_pointer);
486533
@@ -506,7 +553,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
506553
507554 VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph (flatbuffer_data);
508555
509- GraphBuilder builder (compute_graph, flatbuffer_graph, constant_data);
556+ GraphBuilder builder (
557+ compute_graph, flatbuffer_graph, constant_data, named_data_map);
510558
511559 builder.build_graph ();
512560
@@ -532,7 +580,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
532580 graph_config.external_adapter = vkapi::set_and_get_external_adapter ();
533581 new (compute_graph) ComputeGraph (graph_config);
534582
535- Error err = compileModel (processed->data (), compute_graph);
583+ const NamedDataMap* named_data_map = context.get_named_data_map ();
584+ Error err = compileModel (processed->data (), compute_graph, named_data_map);
536585
537586 // This backend does not need its processed data after compiling the
538587 // model.
0 commit comments