22
22
#include < executorch/runtime/core/event_tracer_hooks_delegate.h>
23
23
#endif // ET_EVENT_TRACER_ENABLED
24
24
#include < executorch/runtime/core/exec_aten/util/tensor_util.h>
25
+ #include < executorch/runtime/core/named_data_map.h>
25
26
#include < executorch/runtime/platform/compiler.h>
26
27
#include < executorch/runtime/platform/profiler.h>
27
28
@@ -47,6 +48,7 @@ using executorch::runtime::Error;
47
48
using executorch::runtime::EValue;
48
49
using executorch::runtime::FreeableBuffer;
49
50
using executorch::runtime::kTensorDimensionLimit ;
51
+ using executorch::runtime::NamedDataMap;
50
52
using executorch::runtime::Result;
51
53
using executorch::runtime::Span;
52
54
@@ -66,14 +68,6 @@ using BytesVector =
66
68
const flatbuffers::Vector<flatbuffers::Offset<vkgraph::VkBytes>>*;
67
69
using UIntVector = const flatbuffers::Vector<uint32_t >*;
68
70
69
- const uint8_t * get_constant_data_ptr (
70
- VkGraphPtr flatbuffer_graph,
71
- const int32_t buffer_idx,
72
- const uint8_t * constant_data) {
73
- VkBytesPtr constant_bytes = flatbuffer_graph->constants ()->Get (buffer_idx);
74
- return constant_data + constant_bytes->offset ();
75
- }
76
-
77
71
vkapi::ScalarType get_scalar_type (const vkgraph::VkDataType& vk_datatype) {
78
72
switch (vk_datatype) {
79
73
case vkgraph::VkDataType::BOOL:
@@ -166,17 +160,22 @@ class GraphBuilder {
166
160
ComputeGraph* compute_graph_;
167
161
VkGraphPtr flatbuffer_;
168
162
const uint8_t * constant_data_;
163
+ const NamedDataMap* named_data_map_;
164
+ std::vector<FreeableBuffer> loaded_buffers_from_map_;
169
165
170
166
std::vector<ValueRef> ref_mapping_;
171
167
172
168
public:
173
169
explicit GraphBuilder (
174
170
ComputeGraph* compute_graph,
175
171
VkGraphPtr flatbuffer,
176
- const uint8_t * constant_data)
172
+ const uint8_t * constant_data,
173
+ const NamedDataMap* named_data_map)
177
174
: compute_graph_(compute_graph),
178
175
flatbuffer_(flatbuffer),
179
176
constant_data_(constant_data),
177
+ named_data_map_(named_data_map),
178
+ loaded_buffers_from_map_(),
180
179
ref_mapping_() {}
181
180
182
181
void resize (uint32_t size) {
@@ -212,10 +211,27 @@ class GraphBuilder {
212
211
213
212
ValueRef ref;
214
213
if (tensor_fb->constant_id () >= 0 ) {
215
- const uint8_t * tensor_data = get_constant_data_ptr (
216
- flatbuffer_, tensor_fb->constant_id (), constant_data_ );
214
+ VkBytesPtr constant_bytes =
215
+ flatbuffer_-> constants ()-> Get ( tensor_fb->constant_id ());
217
216
218
- ref = compute_graph_->add_tensorref (dims_vector, dtype, tensor_data);
217
+ if (constant_bytes->named_key () != nullptr &&
218
+ constant_bytes->offset () == UINT64_MAX &&
219
+ named_data_map_ != nullptr ) {
220
+ const std::string& data_name = constant_bytes->named_key ()->str ();
221
+ Result<FreeableBuffer> buffer =
222
+ named_data_map_->get_data (data_name.c_str ());
223
+
224
+ VK_CHECK_COND (
225
+ buffer.ok (),
226
+ " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
227
+ data_name.c_str (),
228
+ static_cast <uint32_t >(buffer.error ()));
229
+ ref = compute_graph_->add_tensorref (
230
+ dims_vector, dtype, std::move (buffer.get ()));
231
+ } else {
232
+ const uint8_t * tensor_data = constant_data_ + constant_bytes->offset ();
233
+ ref = compute_graph_->add_tensorref (dims_vector, dtype, tensor_data);
234
+ }
219
235
} else {
220
236
ref = compute_graph_->add_tensor (
221
237
dims_vector,
@@ -479,8 +495,10 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
479
495
return true ;
480
496
}
481
497
482
- ET_NODISCARD Error
483
- compileModel (const void * buffer_pointer, ComputeGraph* compute_graph) const {
498
+ ET_NODISCARD Error compileModel (
499
+ const void * buffer_pointer,
500
+ ComputeGraph* compute_graph,
501
+ const NamedDataMap* named_data_map) const {
484
502
Result<VulkanDelegateHeader> header =
485
503
VulkanDelegateHeader::parse (buffer_pointer);
486
504
@@ -506,7 +524,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
506
524
507
525
VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph (flatbuffer_data);
508
526
509
- GraphBuilder builder (compute_graph, flatbuffer_graph, constant_data);
527
+ GraphBuilder builder (
528
+ compute_graph, flatbuffer_graph, constant_data, named_data_map);
510
529
511
530
builder.build_graph ();
512
531
@@ -532,7 +551,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
532
551
graph_config.external_adapter = vkapi::set_and_get_external_adapter ();
533
552
new (compute_graph) ComputeGraph (graph_config);
534
553
535
- Error err = compileModel (processed->data (), compute_graph);
554
+ const NamedDataMap* named_data_map = context.get_named_data_map ();
555
+ Error err = compileModel (processed->data (), compute_graph, named_data_map);
536
556
537
557
// This backend does not need its processed data after compiling the
538
558
// model.
0 commit comments