diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 21fd137b65b..b32f4eb4308 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -152,7 +152,7 @@ class GraphBuilder { VkGraphPtr flatbuffer_; const uint8_t* constant_data_; - std::unordered_map ref_mapping_; + std::vector ref_mapping_; public: explicit GraphBuilder( @@ -164,22 +164,20 @@ class GraphBuilder { constant_data_(constant_data), ref_mapping_() {} - bool fb_id_exists(const uint32_t fb_id) { - const std::unordered_map::iterator found_ref = - ref_mapping_.find(fb_id); + void resize(uint32_t size) { + ref_mapping_.resize(size, INT32_MAX); + } - return found_ref != ref_mapping_.end(); + bool fb_id_exists(const uint32_t fb_id) { + return fb_id < ref_mapping_.size() && ref_mapping_[fb_id] != INT32_MAX; } ValueRef get_fb_id_valueref(const uint32_t fb_id) { - const std::unordered_map::iterator found_ref = - ref_mapping_.find(fb_id); - ET_CHECK_MSG( - found_ref != ref_mapping_.end(), + fb_id_exists(fb_id), "Trying to extract a value that hasn't yet been added to the graph."); - return found_ref->second; + return ref_mapping_[fb_id]; } void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) { @@ -315,6 +313,9 @@ class GraphBuilder { } void build_graph() { + // Resize the mapping to the number of values in the flatbuffer + resize(flatbuffer_->values()->size()); + // First, add all values to the graph for (uint32_t fb_id = 0; fb_id < flatbuffer_->values()->size(); ++fb_id) { VkValuePtr value = flatbuffer_->values()->Get(fb_id); @@ -489,8 +490,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data); - GraphBuilder builder = - GraphBuilder(compute_graph, flatbuffer_graph, constant_data); + GraphBuilder builder(compute_graph, flatbuffer_graph, constant_data); builder.build_graph();