From 8ebea8bf8d7d0da710129c7064797725cad29ae5 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 18 Aug 2025 10:28:10 -0700 Subject: [PATCH 1/4] [ET-VK][ez] Move execute node threshold calculation from `prepare_pipelines()` to `prepare()` Title says it all; `prepare()` is a more appropriate place for this action than `prepare_pipelines()`. ## Motivation Fix potential floating point exception (divide-by-zero) during tests. Some tests don't call `prepare_pipelines()`, which means `execute_threshold_node_count_` is unititialized, causing a divide by zero in execute when trying to modulo with `execute_threshold_node_count_` Differential Revision: [D80468138](https://our.internmc.facebook.com/intern/diff/D80468138/) ghstack-source-id: 303779589 Pull Request resolved: https://github.com/pytorch/executorch/pull/13478 --- .../vulkan/runtime/graph/ComputeGraph.cpp | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index acd20c9ee44..33bfe8e3675 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -813,25 +813,8 @@ void ComputeGraph::prepare() { context_->initialize_querypool(); } - for (SharedObject& shared_object : shared_objects_) { - shared_object.allocate(this); - shared_object.bind_users(this); - } -} - -void ComputeGraph::prepare_pipelines() { - for (std::unique_ptr& node : prepack_nodes_) { - node->prepare_pipelines(this); - } - for (std::unique_ptr& node : execute_nodes_) { - node->prepare_pipelines(this); - } - context_->pipeline_cache().create_pipelines(pipeline_descriptors_); - - pipeline_descriptors_ = std::unordered_set< - vkapi::ComputePipelineCache::Key, - vkapi::ComputePipelineCache::Hasher>(); - + // Calculate the threshold at which a new command buffer should be created + // during execute() const size_t total_node_count = execute_nodes_.size(); size_t init_threshold = config_.execute_initial_threshold_node_count; size_t count_threshold = config_.execute_threshold_node_count; @@ -858,6 +841,25 @@ void ComputeGraph::prepare_pipelines() { } execute_threshold_node_count_ = count_threshold; + + for (SharedObject& shared_object : shared_objects_) { + shared_object.allocate(this); + shared_object.bind_users(this); + } +} + +void ComputeGraph::prepare_pipelines() { + for (std::unique_ptr& node : prepack_nodes_) { + node->prepare_pipelines(this); + } + for (std::unique_ptr& node : execute_nodes_) { + node->prepare_pipelines(this); + } + context_->pipeline_cache().create_pipelines(pipeline_descriptors_); + + pipeline_descriptors_ = std::unordered_set< + vkapi::ComputePipelineCache::Key, + vkapi::ComputePipelineCache::Hasher>(); } void ComputeGraph::submit_current_cmd(const bool final_use) { From 5b3585c4fad2452d6cffedeea204f5f33f77772c Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 18 Aug 2025 13:14:23 -0700 Subject: [PATCH 2/4] [ET-VK] Runtime support for NamedDataMap Pull Request resolved: https://github.com/pytorch/executorch/pull/13472 Allow VulkanBackend to load constant tensors from the NamedDataMap instead of the constant data section of the delegate blob. ## Motivation This enables several key results: * Unblocks delegate retargetability with other backends * Allows reducing peak memory usage when loading models by freeing constant weight data as it gets moved to the GPU ## Changes * Allow `TensorRef` to be constructed with a `FreeableBuffer` rvalue * Add ability to load constant data from `NamedDataMap` in `VulkanBackend.cpp` * When prepacking, free the constant data pointer once it's been copied to the staging buffer ghstack-source-id: 303830113 Differential Revision: [D80460035](https://our.internmc.facebook.com/intern/diff/D80460035/) --- backends/vulkan/runtime/VulkanBackend.cpp | 52 +++++++++++++------ .../vulkan/runtime/graph/ComputeGraph.cpp | 11 ++++ backends/vulkan/runtime/graph/ComputeGraph.h | 10 ++++ backends/vulkan/runtime/graph/Logging.cpp | 2 +- .../runtime/graph/containers/Constant.cpp | 17 +++++- .../runtime/graph/containers/Constant.h | 17 ++++++ .../vulkan/runtime/graph/ops/PrepackNode.cpp | 3 ++ backends/vulkan/serialization/schema.fbs | 1 + .../serialization/vulkan_graph_schema.py | 1 + backends/vulkan/targets.bzl | 3 +- .../vulkan/test/vulkan_compute_api_test.cpp | 8 +-- 11 files changed, 102 insertions(+), 23 deletions(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 73b726bd32e..7b138072d50 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -22,6 +22,7 @@ #include #endif // ET_EVENT_TRACER_ENABLED #include +#include #include #include @@ -47,6 +48,7 @@ using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::FreeableBuffer; using executorch::runtime::kTensorDimensionLimit; +using executorch::runtime::NamedDataMap; using executorch::runtime::Result; using executorch::runtime::Span; @@ -66,14 +68,6 @@ using BytesVector = const flatbuffers::Vector>*; using UIntVector = const flatbuffers::Vector*; -const uint8_t* get_constant_data_ptr( - VkGraphPtr flatbuffer_graph, - const int32_t buffer_idx, - const uint8_t* constant_data) { - VkBytesPtr constant_bytes = flatbuffer_graph->constants()->Get(buffer_idx); - return constant_data + constant_bytes->offset(); -} - vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) { switch (vk_datatype) { case vkgraph::VkDataType::BOOL: @@ -166,6 +160,8 @@ class GraphBuilder { ComputeGraph* compute_graph_; VkGraphPtr flatbuffer_; const uint8_t* constant_data_; + const NamedDataMap* named_data_map_; + std::vector loaded_buffers_from_map_; std::vector ref_mapping_; @@ -173,10 +169,13 @@ class GraphBuilder { explicit GraphBuilder( ComputeGraph* compute_graph, VkGraphPtr flatbuffer, - const uint8_t* constant_data) + const uint8_t* constant_data, + const NamedDataMap* named_data_map) : compute_graph_(compute_graph), flatbuffer_(flatbuffer), constant_data_(constant_data), + named_data_map_(named_data_map), + loaded_buffers_from_map_(), ref_mapping_() {} void resize(uint32_t size) { @@ -212,10 +211,27 @@ class GraphBuilder { ValueRef ref; if (tensor_fb->constant_id() >= 0) { - const uint8_t* tensor_data = get_constant_data_ptr( - flatbuffer_, tensor_fb->constant_id(), constant_data_); + VkBytesPtr constant_bytes = + flatbuffer_->constants()->Get(tensor_fb->constant_id()); - ref = compute_graph_->add_tensorref(dims_vector, dtype, tensor_data); + if (constant_bytes->named_key() != nullptr && + constant_bytes->offset() == UINT64_MAX && + named_data_map_ != nullptr) { + const std::string& data_name = constant_bytes->named_key()->str(); + Result buffer = + named_data_map_->get_data(data_name.c_str()); + + VK_CHECK_COND( + buffer.ok(), + "Failed to get constant data for key %s from named_data_map. Error code: %u", + data_name.c_str(), + static_cast(buffer.error())); + ref = compute_graph_->add_tensorref( + dims_vector, dtype, std::move(buffer.get())); + } else { + const uint8_t* tensor_data = constant_data_ + constant_bytes->offset(); + ref = compute_graph_->add_tensorref(dims_vector, dtype, tensor_data); + } } else { ref = compute_graph_->add_tensor( dims_vector, @@ -479,8 +495,10 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { return true; } - ET_NODISCARD Error - compileModel(const void* buffer_pointer, ComputeGraph* compute_graph) const { + ET_NODISCARD Error compileModel( + const void* buffer_pointer, + ComputeGraph* compute_graph, + const NamedDataMap* named_data_map) const { Result header = VulkanDelegateHeader::parse(buffer_pointer); @@ -506,7 +524,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data); - GraphBuilder builder(compute_graph, flatbuffer_graph, constant_data); + GraphBuilder builder( + compute_graph, flatbuffer_graph, constant_data, named_data_map); builder.build_graph(); @@ -532,7 +551,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { graph_config.external_adapter = vkapi::set_and_get_external_adapter(); new (compute_graph) ComputeGraph(graph_config); - Error err = compileModel(processed->data(), compute_graph); + const NamedDataMap* named_data_map = context.get_named_data_map(); + Error err = compileModel(processed->data(), compute_graph, named_data_map); // This backend does not need its processed data after compiling the // model. diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 33bfe8e3675..d57ba2b11d7 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -480,6 +480,17 @@ ValueRef ComputeGraph::add_tensorref( return idx; } +ValueRef ComputeGraph::add_tensorref( + const std::vector& sizes, + const vkapi::ScalarType dtype, + executorch::runtime::FreeableBuffer&& buffer) { + ValueRef idx(static_cast(values_.size())); + check_no_active_value_ptrs(); + values_.emplace_back(TensorRef(sizes, dtype, std::move(buffer))); + total_constant_nbytes_ += values_.back().toConstTensorRef().nbytes(); + return idx; +} + ValueRef ComputeGraph::add_staging( const vkapi::ScalarType dtype, const size_t numel) { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index e4556a9efe6..f594571f9a7 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -693,6 +693,16 @@ class ComputeGraph final { const vkapi::ScalarType dtype, const void* const data); + /* + * Add a `TensorRef` value to the graph with the specific properties. A + * `TensorRef` is a reference to a `api::vTensor` whose data is stored in a + * FreeableBuffer. The TensorRef will take ownership of the FreeableBuffer. + */ + ValueRef add_tensorref( + const std::vector& sizes, + const vkapi::ScalarType dtype, + executorch::runtime::FreeableBuffer&& buffer); + /* * Add a staging buffer to the graph. Staging buffers are data buffers that * use memory that is visible to both the CPU and GPU, and therefore is used diff --git a/backends/vulkan/runtime/graph/Logging.cpp b/backends/vulkan/runtime/graph/Logging.cpp index 7102345773c..081083e3a63 100644 --- a/backends/vulkan/runtime/graph/Logging.cpp +++ b/backends/vulkan/runtime/graph/Logging.cpp @@ -86,7 +86,7 @@ void ComputeGraph::print_readable() { ss << v_tensor.sizes(); std::cout << ss.str(); } else if (val.isTensorRef()) { - const TensorRef tensor_ref = val.toTensorRef(); + const TensorRef& tensor_ref = val.toTensorRef(); std::stringstream ss; ss << tensor_ref.sizes; std::cout << ss.str(); diff --git a/backends/vulkan/runtime/graph/containers/Constant.cpp b/backends/vulkan/runtime/graph/containers/Constant.cpp index cb43295a42a..4dc2cdda8f5 100644 --- a/backends/vulkan/runtime/graph/containers/Constant.cpp +++ b/backends/vulkan/runtime/graph/containers/Constant.cpp @@ -14,7 +14,22 @@ TensorRef::TensorRef( const std::vector& t_sizes, vkapi::ScalarType t_dtype, const void* const t_data) - : sizes{}, dtype{t_dtype}, data{t_data} { + : sizes{}, dtype{t_dtype}, data{t_data}, buffer{} { + size_t ndim = t_sizes.size(); + sizes.resize(ndim); + for (int i = 0; i < ndim; ++i) { + sizes[i] = t_sizes.at(i); + } +} + +TensorRef::TensorRef( + const std::vector& t_sizes, + vkapi::ScalarType t_dtype, + executorch::runtime::FreeableBuffer&& t_buffer) + : sizes{}, + dtype{t_dtype}, + data{t_buffer.data()}, + buffer{std::move(t_buffer)} { size_t ndim = t_sizes.size(); sizes.resize(ndim); for (int i = 0; i < ndim; ++i) { diff --git a/backends/vulkan/runtime/graph/containers/Constant.h b/backends/vulkan/runtime/graph/containers/Constant.h index aaa92360a9e..a18c284a219 100644 --- a/backends/vulkan/runtime/graph/containers/Constant.h +++ b/backends/vulkan/runtime/graph/containers/Constant.h @@ -9,6 +9,7 @@ #pragma once #include +#include namespace vkcompute { @@ -24,14 +25,30 @@ struct TensorRef final { vkapi::ScalarType dtype; const void* data; + // Optional FreeableBuffer for managing memory lifecycle + // This will be empty (default constructed) for the raw pointer constructor + executorch::runtime::FreeableBuffer buffer; + explicit TensorRef( const std::vector& t_sizes, vkapi::ScalarType t_dtype, const void* const t_data); + // Constructor that takes ownership of a FreeableBuffer + explicit TensorRef( + const std::vector& t_sizes, + vkapi::ScalarType t_dtype, + executorch::runtime::FreeableBuffer&& t_buffer); + inline size_t nbytes() const { return utils::multiply_integers(sizes) * vkapi::element_size(dtype); } + + // Manually free the buffer if needed (though it will be freed automatically + // on destruction) + void free_buffer() { + buffer.Free(); + } }; } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index c8220df837b..03df92292f8 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -64,6 +64,9 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) { graph->update_staging_nbytes_in_cmd(staging.buffer().mem_size_as_size_t()); size_t nbytes = numel * vkapi::element_size(tref->dtype); staging.copy_from(tref->data, nbytes); + // Once the staging buffer is copied, if the TensorRef owns a FreeableBuffer, + // it can be freed. + tref->free_buffer(); return staging; } diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index 99ba6a86594..b6670b6f53d 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -118,6 +118,7 @@ table VkValue { table VkBytes { offset:ulong; length:ulong; + named_key:string; } table VkGraph { diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index f845e5601a7..aa7641bd927 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -137,6 +137,7 @@ class VkValue: class VkBytes: offset: int length: int + named_key: str = "" @dataclass diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index ac26d202fe1..b9b96abdec4 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -263,6 +263,7 @@ def define_common_targets(is_fbcode = False): ], exported_deps = [ ":vulkan_graph_runtime_shaderlib{}".format(suffix), + "//executorch/runtime/backend:interface", ], define_static_target = True, # Static initialization is used to register operators to the global operator registry, @@ -303,8 +304,8 @@ def define_common_targets(is_fbcode = False): ":vulkan_graph_runtime{}".format(suffix), "//executorch/backends/vulkan/serialization:vk_delegate_schema", "//executorch/runtime/core:event_tracer", - "//executorch/runtime/backend:interface", "//executorch/runtime/core/exec_aten/util:tensor_util", + "//executorch/runtime/core:named_data_map", ], define_static_target = True, # VulkanBackend.cpp needs to compile with executor as whole diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index f99552ceee1..96adc13d3cd 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1036,12 +1036,12 @@ TEST_F(VulkanComputeAPITest, print_object_sizes) { // Current known size on 64 bit system: 1040 B EXPECT_TRUE(sizeof(vTensor) < 1200); - // Current known size on 64 bit system: 48 B - EXPECT_TRUE(sizeof(Value) < 56); + // Current known size on 64 bit system: 80 B + EXPECT_TRUE(sizeof(Value) < 100); // Current known size on 64 bit system: 120 B EXPECT_TRUE(sizeof(StagingBuffer) < 500); - // Current known size on 64 bit system: 512 B - EXPECT_TRUE(sizeof(ComputeGraph) < 600); + // Current known size on 64 bit system: 608 B + EXPECT_TRUE(sizeof(ComputeGraph) < 700); // Current known size on 64 bit system: 248 B EXPECT_TRUE(sizeof(DispatchNode) < 500); } From ab383fd1c3b0fb877d80b322455a09739612ae0d Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 18 Aug 2025 13:14:25 -0700 Subject: [PATCH 3/4] [ET-VK][AOT] Serialize constant tensors via NamedDataMap Pull Request resolved: https://github.com/pytorch/executorch/pull/13473 When exporting models to Vulkan backend, save constant tensors in the NamedDataMap instead of the constant data section of the delegate header. ## Motivation Prevent screen blackout (Llama 3.2 1B) / device crash (Llama 3.2 3B) when running Llama 3.2 models on Samsung Galaxy S24. This behaviour is related to high peak memory usage when loading the model. For more information, see the top diff/PR in the stack. ## Context This change is based on the equivalent change D70315207/https://github.com/pytorch/executorch/pull/9153 in XNNPACK. ghstack-source-id: 303830114 Differential Revision: [D80460034](https://our.internmc.facebook.com/intern/diff/D80460034/) --- .../serialization/vulkan_graph_builder.py | 36 +++++++++++++++++-- .../serialization/vulkan_graph_serialize.py | 19 ++++++++-- backends/vulkan/vulkan_preprocess.py | 1 + 3 files changed, 51 insertions(+), 5 deletions(-) diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index b74a7fb1f8e..78ac51c8808 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import ctypes +import hashlib import logging import operator from types import NoneType @@ -25,6 +27,7 @@ is_symint_node, TensorRepr, ) +from executorch.exir._serialize._named_data_store import NamedDataStore from executorch.exir.backend.utils import DelegateMappingBuilder from executorch.exir.tensor import TensorSpec @@ -56,6 +59,7 @@ def __init__( self.input_ids = [] self.output_ids = [] self.const_tensors = [] + self.named_data_store = NamedDataStore() # Mapping from Node to VkValue id self.node_to_value_ids = {} @@ -129,8 +133,36 @@ def get_param_tensor(self, node: Node) -> torch.Tensor: def maybe_add_constant_tensor(self, node: Node) -> int: constant_id = -1 if is_param_node(self.program, node): - constant_id = len(self.const_tensors) - self.const_tensors.append(self.get_param_tensor(node)) + tensor = self.get_param_tensor(node) + + # Serialize tensor data to bytes + tensor = tensor.contiguous() + size = tensor.untyped_storage().nbytes() + + if size > 0: + array_type = ctypes.c_char * size + array = ctypes.cast( + tensor.untyped_storage().data_ptr(), + ctypes.POINTER(array_type), + ).contents + + # Generate SHA256 hash as the named key + tensor_bytes = bytes(array) + sha256_hash = hashlib.sha256(tensor_bytes) + named_key = sha256_hash.hexdigest() + + # Add to named data store with 16-byte alignment (matching XNNPACK) + self.named_data_store.add_named_data( + named_key, tensor_bytes, alignment=16 + ) + + # Create VkBytes entry with named_key and set offset to indicate named data usage + constant_id = len(self.const_tensors) + self.const_tensors.append((named_key, size)) + else: + # Handle empty tensors + constant_id = len(self.const_tensors) + self.const_tensors.append(None) return constant_id diff --git a/backends/vulkan/serialization/vulkan_graph_serialize.py b/backends/vulkan/serialization/vulkan_graph_serialize.py index 2ceedf73d10..db682f4e67e 100644 --- a/backends/vulkan/serialization/vulkan_graph_serialize.py +++ b/backends/vulkan/serialization/vulkan_graph_serialize.py @@ -191,10 +191,21 @@ def serialize_constant_tensors( current_offset = len(raw_bytes) for tensor in const_tensors: - if tensor.numel() == 0: + # The tensor data is stored in the named data map + if isinstance(tensor, tuple): + named_key, size = tensor + vk_graph.constants.append( + VkBytes( + offset=18446744073709551615, # UINT64_MAX to indicate named data + length=size, + named_key=named_key, + ) + ) + elif tensor is None or ( + isinstance(tensor, torch.Tensor) and tensor.numel() == 0 + ): vk_graph.constants.append(VkBytes(current_offset, 0)) - continue - else: + elif isinstance(tensor, torch.Tensor): array_type = ctypes.c_char * tensor.untyped_storage().nbytes() array = ctypes.cast( tensor.untyped_storage().data_ptr(), @@ -208,6 +219,8 @@ def serialize_constant_tensors( vk_graph.constants.append(VkBytes(current_offset, len(tensor_bytes))) current_offset += aligned_size(len(tensor_bytes)) + else: + raise ValueError(f"Unsupported constant tensor type: {type(tensor)}") def serialize_custom_shaders( diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 8c1165a89df..1816d9b12de 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -229,4 +229,5 @@ def preprocess( # noqa: C901 vk_graph, graph_builder.const_tensors, [] ), debug_handle_map=graph_builder.delegate_mapping_builder.get_delegate_mapping(), + data_store_output=graph_builder.named_data_store.get_named_data_store_output(), ) From 6c18621c4c96a02310ee53ead9648226cfcc1ca6 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 18 Aug 2025 14:53:24 -0700 Subject: [PATCH 4/4] [ET-VK] Allocate memory for weight and activation tensors lazily Pull Request resolved: https://github.com/pytorch/executorch/pull/13474 * Allocate memory for weight tensors right before the prepacking shader is dispatched, rather than while building the graph * Move allocation of shared objects (i.e. memory for intermediate tensors) to occur after prepacking ## Motivation Prevent screen blackout (Llama 3.2 1B) / device crash (Llama 3.2 3B) when running Llama 3.2 models on Samsung Galaxy S24. This behaviour is related to high peak memory usage when loading the model. ## Full Context During model loading, Vulkan delegate needs to store 3 copies of constant data in memory at various points: * source data obtained from loading the model * staging buffer * GPU texture/buffer The general rationale of this change is to allocate memory for each copy only when necessary to minimize the "overlap" when all 3 exist at once. ### Current Order of operations Legend: * `W` represents total weight nbytes * `w` represents weight nbytes for one tensor * `A` represents total activations nbytes * `M` represents approximation of total memory footprint First, model file is loaded Then, when building compute graph, for each weight tensor: 1. Weight data is loaded from NamedDataMap (`M = W`) 2. GPU texture/buffer for weight is initialized + memory allocated (`M = 2W`) 3. After building the graph, `graph->prepare()` is called which currently allocates memory for the activation tensors as well (`M = 2W + A`) Then, during the prepacking stage for each weight tensor, each weight tensor is copied individually: 1. Staging buffer initialized (`M = 2W + A + w`) 2. Copy CPU weight data to staging + CPU Weight data is freed (`M = 2W + A`) 3. Compute shader dispatch to copy staging to GPU texture/buffer + free staging buffer (`M = 2W + A - w`) The peak usage in mainline will be `M = 2W + A + w` ### Revised order of operations This change revises the order of operations: 1. Weight data is loaded from NamedDataMap (`M = W`) 2. GPU texture/buffer for weight is initialized, but **memory is not allocated** (`M = W`) Then, during the prepacking stage for each weight tensor, each weight tensor is copied individually: 1. Staging buffer initialized (`M = W + w`) 2. **Memory allocated for GPU texture/buffer** (`M = W + 2w`) 3. Copy CPU weight data to staging + CPU Weight data is freed (`M = W + w`) 4. Compute shader dispatch to copy staging to GPU texture/buffer + free staging buffer (`M = W`) **Then, after all prepacking operations complete, only then is Activation memory allocated** (`M = W + A`) Under this scheme, peak memory is reduced to `M = W + A` (or alternatively `M = W + 2w` if `2w > A`) which is (or at least very close to) the theoretical minimum. ghstack-source-id: 303862303 Differential Revision: [D80460033](https://our.internmc.facebook.com/intern/diff/D80460033/) --- .../vulkan/runtime/api/containers/Tensor.cpp | 22 ++++++++ .../vulkan/runtime/api/containers/Tensor.h | 11 ++++ .../vulkan/runtime/graph/ComputeGraph.cpp | 34 +++++++++---- backends/vulkan/runtime/graph/ComputeGraph.h | 7 +++ .../vulkan/runtime/graph/ops/PrepackNode.cpp | 4 ++ .../vulkan/runtime/vk_api/memory/Buffer.cpp | 43 ++++++++++++++-- .../vulkan/runtime/vk_api/memory/Buffer.h | 26 +++++++--- .../vulkan/runtime/vk_api/memory/Image.cpp | 51 +++++++++++++++++-- backends/vulkan/runtime/vk_api/memory/Image.h | 34 +++++++------ .../vulkan/test/vulkan_compute_api_test.cpp | 18 +++++-- 10 files changed, 206 insertions(+), 44 deletions(-) diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index a3d9bd4aa34..6f7167c54fb 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -897,6 +897,16 @@ VkMemoryRequirements vTensor::get_memory_requirements() const { return {}; } +bool vTensor::memory_is_bound() const { + switch (storage_type()) { + case utils::kBuffer: + return storage_->buffer_.has_memory(); + case utils::kTexture2D: + case utils::kTexture3D: + return storage_->image_.has_memory(); + } +} + void vTensor::bind_allocation(const vkapi::Allocation& allocation) { switch (storage_type()) { case utils::kBuffer: @@ -909,6 +919,18 @@ void vTensor::bind_allocation(const vkapi::Allocation& allocation) { } } +void vTensor::acquire_allocation(vkapi::Allocation&& allocation) { + switch (storage_type()) { + case utils::kBuffer: + storage_->buffer_.acquire_allocation(std::move(allocation)); + break; + case utils::kTexture2D: + case utils::kTexture3D: + storage_->image_.acquire_allocation(std::move(allocation)); + break; + } +} + void vTensor::update_metadata() { numel_ = utils::multiply_integers(sizes_); strides_ = calculate_strides(sizes_, dim_order_); diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h index 0e1a1526d88..bcca956e5ea 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.h +++ b/backends/vulkan/runtime/api/containers/Tensor.h @@ -560,6 +560,12 @@ class vTensor final { */ VmaAllocationCreateInfo get_allocation_create_info() const; + /* + * Checks if the tensor's underlying buffer or image resource is bound to a + * memory allocation. + */ + bool memory_is_bound() const; + /* * Return the VkMemoryRequirements of the underlying resource */ @@ -570,6 +576,11 @@ class vTensor final { */ void bind_allocation(const vkapi::Allocation& allocation); + /* + * Binds and acquires a rvalue memory allocation + */ + void acquire_allocation(vkapi::Allocation&& allocation); + private: /* * Assuming sizes, dim order, or axis mapping was modified, recompute all diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index d57ba2b11d7..fff530d57cb 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -356,8 +356,6 @@ ValueRef ComputeGraph::add_tensor( const utils::GPUMemoryLayout memory_layout, const int64_t shared_object_idx, const utils::AxisMapLayout axis_map_layout) { - bool allocate_memory = shared_object_idx < 0; - ValueRef idx(static_cast(values_.size())); check_no_active_value_ptrs(); values_.emplace_back(api::vTensor( @@ -366,10 +364,10 @@ ValueRef ComputeGraph::add_tensor( dtype, storage_type, memory_layout, - allocate_memory, + false, axis_map_layout)); - if (!allocate_memory) { + if (shared_object_idx >= 0) { get_shared_object(shared_object_idx).add_user(this, idx); } return idx; @@ -626,6 +624,17 @@ SharedObject& ComputeGraph::get_shared_object(const int64_t idx) { return shared_objects_.at(idx); } +void ComputeGraph::create_dedicated_allocation_for(const ValueRef idx) { + vTensorPtr tensor = get_tensor(idx); + if (!tensor->memory_is_bound()) { + VmaAllocationCreateInfo alloc_create_info = + context()->adapter_ptr()->vma().gpuonly_resource_create_info(); + tensor->acquire_allocation( + context()->adapter_ptr()->vma().create_allocation( + tensor->get_memory_requirements(), alloc_create_info)); + } +} + void ComputeGraph::update_descriptor_counts( const vkapi::ShaderInfo& shader_info, bool execute) { @@ -852,11 +861,6 @@ void ComputeGraph::prepare() { } execute_threshold_node_count_ = count_threshold; - - for (SharedObject& shared_object : shared_objects_) { - shared_object.allocate(this); - shared_object.bind_users(this); - } } void ComputeGraph::prepare_pipelines() { @@ -952,6 +956,18 @@ void ComputeGraph::prepack() { submit_current_cmd_and_wait(/*final_use=*/true); context_->flush(); staging_nbytes_in_cmd_ = 0; + + // Initialize allocations for intermediate tensors + for (SharedObject& shared_object : shared_objects_) { + shared_object.allocate(this); + shared_object.bind_users(this); + } + // Make sure all remaining tensors have allocations + for (int i = 0; i < values_.size(); i++) { + if (values_.at(i).isTensor()) { + create_dedicated_allocation_for(i); + } + } } void ComputeGraph::execute() { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index f594571f9a7..7686aa65025 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -827,6 +827,13 @@ class ComputeGraph final { SharedObject& get_shared_object(const int64_t idx); + /* + * Creates a dedicated memory allocation for a vTensor value, and have the + * tensor acquire the allocation object. If the tensor is already bound to a + * memory allocation, this function will be a no-op. + */ + void create_dedicated_allocation_for(const ValueRef idx); + // // Graph Preparation // diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index 03df92292f8..62e1dc86f43 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -97,6 +97,10 @@ void PrepackNode::encode(ComputeGraph* graph) { } { + // If the vTensor is not yet bound to a memory allocation, create a new one + // and aquire it. + graph->create_dedicated_allocation_for(packed_); + vkapi::PipelineBarrier pipeline_barrier{}; vkapi::DescriptorSet descriptor_set = context->get_descriptor_set( shader_, local_workgroup_size_, spec_vars_, push_constants_offset); diff --git a/backends/vulkan/runtime/vk_api/memory/Buffer.cpp b/backends/vulkan/runtime/vk_api/memory/Buffer.cpp index 4f58e07b146..f10e40abdbb 100644 --- a/backends/vulkan/runtime/vk_api/memory/Buffer.cpp +++ b/backends/vulkan/runtime/vk_api/memory/Buffer.cpp @@ -20,6 +20,7 @@ VulkanBuffer::VulkanBuffer() allocator_(VK_NULL_HANDLE), memory_{}, owns_memory_(false), + memory_bundled_(false), is_copy_(false), handle_(VK_NULL_HANDLE) {} @@ -33,6 +34,7 @@ VulkanBuffer::VulkanBuffer( allocator_(vma_allocator), memory_{}, owns_memory_(allocate_memory), + memory_bundled_(allocate_memory), is_copy_(false), handle_(VK_NULL_HANDLE) { // If the buffer size is 0, allocate a buffer with a size of 1 byte. This is @@ -77,6 +79,7 @@ VulkanBuffer::VulkanBuffer( allocator_(other.allocator_), memory_(other.memory_), owns_memory_(false), + memory_bundled_(false), is_copy_(true), handle_(other.handle_) { // TODO: set the offset and range appropriately @@ -91,6 +94,7 @@ VulkanBuffer::VulkanBuffer(VulkanBuffer&& other) noexcept allocator_(other.allocator_), memory_(std::move(other.memory_)), owns_memory_(other.owns_memory_), + memory_bundled_(other.memory_bundled_), is_copy_(other.is_copy_), handle_(other.handle_) { other.handle_ = VK_NULL_HANDLE; @@ -99,16 +103,19 @@ VulkanBuffer::VulkanBuffer(VulkanBuffer&& other) noexcept VulkanBuffer& VulkanBuffer::operator=(VulkanBuffer&& other) noexcept { VkBuffer tmp_buffer = handle_; bool tmp_owns_memory = owns_memory_; + bool tmp_memory_bundled = memory_bundled_; buffer_properties_ = other.buffer_properties_; allocator_ = other.allocator_; memory_ = std::move(other.memory_); owns_memory_ = other.owns_memory_; + memory_bundled_ = other.memory_bundled_; is_copy_ = other.is_copy_; handle_ = other.handle_; other.handle_ = tmp_buffer; other.owns_memory_ = tmp_owns_memory; + other.memory_bundled_ = tmp_memory_bundled; return *this; } @@ -119,14 +126,22 @@ VulkanBuffer::~VulkanBuffer() { // ownership of the underlying resource. if (handle_ != VK_NULL_HANDLE && !is_copy_) { if (owns_memory_) { - vmaDestroyBuffer(allocator_, handle_, memory_.allocation); + if (memory_bundled_) { + vmaDestroyBuffer(allocator_, handle_, memory_.allocation); + // Prevent the underlying memory allocation from being freed; it was + // freed by vmaDestroyImage + memory_.allocation = VK_NULL_HANDLE; + } else { + vkDestroyBuffer(this->device(), handle_, nullptr); + // Allow underlying memory allocation to be freed by the destructor of + // Allocation class + } } else { vkDestroyBuffer(this->device(), handle_, nullptr); + // Prevent the underlying memory allocation from being freed since this + // object doesn't own it + memory_.allocation = VK_NULL_HANDLE; } - // Prevent the underlying memory allocation from being freed; it was either - // freed by vmaDestroyBuffer, or this resource does not own the underlying - // memory - memory_.allocation = VK_NULL_HANDLE; } } @@ -136,6 +151,24 @@ VmaAllocationInfo VulkanBuffer::allocation_info() const { return info; } +void VulkanBuffer::bind_allocation_impl(const Allocation& memory) { + VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!"); + if (!is_copy_) { + VK_CHECK(vmaBindBufferMemory(allocator_, memory.allocation, handle_)); + } +} + +void VulkanBuffer::bind_allocation(const Allocation& memory) { + bind_allocation_impl(memory); + memory_.allocation = memory.allocation; +} + +void VulkanBuffer::acquire_allocation(Allocation&& memory) { + bind_allocation_impl(memory); + memory_ = std::move(memory); + owns_memory_ = true; +} + VkMemoryRequirements VulkanBuffer::get_memory_requirements() const { VkMemoryRequirements memory_requirements; vkGetBufferMemoryRequirements(this->device(), handle_, &memory_requirements); diff --git a/backends/vulkan/runtime/vk_api/memory/Buffer.h b/backends/vulkan/runtime/vk_api/memory/Buffer.h index e1b441397b4..582b537465d 100644 --- a/backends/vulkan/runtime/vk_api/memory/Buffer.h +++ b/backends/vulkan/runtime/vk_api/memory/Buffer.h @@ -100,6 +100,10 @@ class VulkanBuffer final { Allocation memory_; // Indicates whether the underlying memory is owned by this resource bool owns_memory_; + // Indicates whether the allocation for the buffer was created with the buffer + // via vmaCreateBuffer; if this is false, the memory is owned but was bound + // separately via vmaBindBufferMemory + bool memory_bundled_; // Indicates whether this VulkanBuffer was copied from another VulkanBuffer, // thus it does not have ownership of the underlying VKBuffer bool is_copy_; @@ -162,13 +166,21 @@ class VulkanBuffer final { return (handle_ == other.handle_) && is_copy_; } - inline void bind_allocation(const Allocation& memory) { - VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!"); - if (!is_copy_) { - VK_CHECK(vmaBindBufferMemory(allocator_, memory.allocation, handle_)); - } - memory_.allocation = memory.allocation; - } + private: + void bind_allocation_impl(const Allocation& memory); + + public: + /* + * Given a memory allocation, bind it to the underlying VkImage. The lifetime + * of the memory allocation is assumed to be managed externally. + */ + void bind_allocation(const Allocation& memory); + + /* + * Given a rvalue memory allocation, bind it to the underlying VkImage and + * also acquire ownership of the memory allocation. + */ + void acquire_allocation(Allocation&& memory); VkMemoryRequirements get_memory_requirements() const; diff --git a/backends/vulkan/runtime/vk_api/memory/Image.cpp b/backends/vulkan/runtime/vk_api/memory/Image.cpp index da6ff76bccd..cadeb779c83 100644 --- a/backends/vulkan/runtime/vk_api/memory/Image.cpp +++ b/backends/vulkan/runtime/vk_api/memory/Image.cpp @@ -99,6 +99,7 @@ VulkanImage::VulkanImage() allocator_(VK_NULL_HANDLE), memory_{}, owns_memory_(false), + memory_bundled_(false), owns_view_(false), is_copy_(false), handles_{ @@ -125,6 +126,7 @@ VulkanImage::VulkanImage( allocator_(vma_allocator), memory_{}, owns_memory_{allocate_memory}, + memory_bundled_(allocate_memory), owns_view_(false), is_copy_(false), handles_{ @@ -195,6 +197,7 @@ VulkanImage::VulkanImage( allocator_(VK_NULL_HANDLE), memory_{}, owns_memory_(false), + memory_bundled_(false), is_copy_(false), handles_{ image, @@ -224,6 +227,7 @@ VulkanImage::VulkanImage(VulkanImage&& other) noexcept allocator_(other.allocator_), memory_(std::move(other.memory_)), owns_memory_(other.owns_memory_), + memory_bundled_(other.memory_bundled_), owns_view_(other.owns_view_), is_copy_(other.is_copy_), handles_(other.handles_), @@ -232,12 +236,14 @@ VulkanImage::VulkanImage(VulkanImage&& other) noexcept other.handles_.image_view = VK_NULL_HANDLE; other.handles_.sampler = VK_NULL_HANDLE; other.owns_memory_ = false; + other.memory_bundled_ = false; } VulkanImage& VulkanImage::operator=(VulkanImage&& other) noexcept { VkImage tmp_image = handles_.image; VkImageView tmp_image_view = handles_.image_view; bool tmp_owns_memory = owns_memory_; + bool tmp_memory_bundled = memory_bundled_; device_ = other.device_; image_properties_ = other.image_properties_; @@ -246,6 +252,7 @@ VulkanImage& VulkanImage::operator=(VulkanImage&& other) noexcept { allocator_ = other.allocator_; memory_ = std::move(other.memory_); owns_memory_ = other.owns_memory_; + memory_bundled_ = other.memory_bundled_; is_copy_ = other.is_copy_; handles_ = other.handles_; layout_ = other.layout_; @@ -253,6 +260,7 @@ VulkanImage& VulkanImage::operator=(VulkanImage&& other) noexcept { other.handles_.image = tmp_image; other.handles_.image_view = tmp_image_view; other.owns_memory_ = tmp_owns_memory; + other.memory_bundled_ = tmp_memory_bundled; return *this; } @@ -271,14 +279,22 @@ VulkanImage::~VulkanImage() { if (handles_.image != VK_NULL_HANDLE) { if (owns_memory_) { - vmaDestroyImage(allocator_, handles_.image, memory_.allocation); + if (memory_bundled_) { + vmaDestroyImage(allocator_, handles_.image, memory_.allocation); + // Prevent the underlying memory allocation from being freed; it was + // freed by vmaDestroyImage + memory_.allocation = VK_NULL_HANDLE; + } else { + vkDestroyImage(this->device(), handles_.image, nullptr); + // Allow underlying memory allocation to be freed by the destructor of + // Allocation class + } } else { vkDestroyImage(this->device(), handles_.image, nullptr); + // Prevent the underlying memory allocation from being freed since this + // object doesn't own it + memory_.allocation = VK_NULL_HANDLE; } - // Prevent the underlying memory allocation from being freed; it was either - // freed by vmaDestroyImage, or this resource does not own the underlying - // memory - memory_.allocation = VK_NULL_HANDLE; } } @@ -319,6 +335,31 @@ void VulkanImage::create_image_view() { &(handles_.image_view))); } +void VulkanImage::bind_allocation_impl(const Allocation& memory) { + VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!"); + // To prevent multiple instances of binding the same VkImage to a memory + // block, do not actually bind memory if this VulkanImage is a copy. Assume + // that the original VulkanImage is responsible for binding the image. + if (!is_copy_) { + VK_CHECK(vmaBindImageMemory(allocator_, memory.allocation, handles_.image)); + } + + // Only create the image view if the image has been bound to memory + owns_view_ = true; + create_image_view(); +} + +void VulkanImage::bind_allocation(const Allocation& memory) { + bind_allocation_impl(memory); + memory_.allocation = memory.allocation; +} + +void VulkanImage::acquire_allocation(Allocation&& memory) { + bind_allocation_impl(memory); + memory_ = std::move(memory); + owns_memory_ = true; +} + VkMemoryRequirements VulkanImage::get_memory_requirements() const { VkMemoryRequirements memory_requirements; vkGetImageMemoryRequirements( diff --git a/backends/vulkan/runtime/vk_api/memory/Image.h b/backends/vulkan/runtime/vk_api/memory/Image.h index 5bbdaf06b47..db632c34378 100644 --- a/backends/vulkan/runtime/vk_api/memory/Image.h +++ b/backends/vulkan/runtime/vk_api/memory/Image.h @@ -156,6 +156,10 @@ class VulkanImage final { Allocation memory_; // Indicates whether the underlying memory is owned by this resource bool owns_memory_; + // Indicates whether the allocation for the image was created with the image + // via vmaCreateImage; if this is false, the memory is owned but was bound + // separately via vmaBindImageMemory + bool memory_bundled_; // In some cases, a VulkanImage may be a copy of another VulkanImage but still // own a unique view of the VkImage. bool owns_view_; @@ -242,21 +246,21 @@ class VulkanImage final { return (handles_.image == other.handles_.image) && is_copy_; } - inline void bind_allocation(const Allocation& memory) { - VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!"); - // To prevent multiple instances of binding the same VkImage to a memory - // block, do not actually bind memory if this VulkanImage is a copy. Assume - // that the original VulkanImage is responsible for binding the image. - if (!is_copy_) { - VK_CHECK( - vmaBindImageMemory(allocator_, memory.allocation, handles_.image)); - } - memory_.allocation = memory.allocation; - - // Only create the image view if the image has been bound to memory - owns_view_ = true; - create_image_view(); - } + private: + void bind_allocation_impl(const Allocation& memory); + + public: + /* + * Given a memory allocation, bind it to the underlying VkImage. The lifetime + * of the memory allocation is assumed to be managed externally. + */ + void bind_allocation(const Allocation& memory); + + /* + * Given a rvalue memory allocation, bind it to the underlying VkImage and + * also acquire ownership of the memory allocation. + */ + void acquire_allocation(Allocation&& memory); VkMemoryRequirements get_memory_requirements() const; diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 96adc13d3cd..9a857f41fde 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1176,6 +1176,7 @@ TEST(VulkanComputeGraphTest, test_zero_dim_tensor) { out.staging = graph.set_output_tensor(out.value); graph.prepare(); + graph.prepack(); // Run graph @@ -1218,6 +1219,7 @@ TEST(VulkanComputeGraphTest, test_simple_graph_with_buffer) { out.staging = graph.set_output_tensor(out.value); graph.prepare(); + graph.prepack(); // Run graph @@ -1303,6 +1305,7 @@ TEST(VulkanComputeGraphTest, test_simple_graph) { out.staging = graph.set_output_tensor(out.value); graph.prepare(); + graph.prepack(); // Run graph @@ -1361,6 +1364,7 @@ TEST(VulkanComputeGraphTest, test_simple_graph_with_symint) { out.staging = graph.set_output_tensor(out.value); graph.prepare(); + graph.prepack(); // Run graph @@ -1519,6 +1523,7 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { EXPECT_EQ(get_vma_allocation_count(), expected_vma_allocation_count); graph.prepare(); + graph.prepack(); // +3: shared memory allocations for tensors expected_vma_allocation_count += 3; @@ -1659,6 +1664,7 @@ TEST(VulkanComputeGraphTest, test_simple_graph_with_tmp_tensors) { out.staging = graph.set_output_tensor(out.value); graph.prepare(); + graph.prepack(); // Run graph @@ -1725,6 +1731,7 @@ TEST(VulkanComputeGraphTest, test_large_graph) { out.staging = graph.set_output_tensor(out.value); graph.prepare(); + graph.prepack(); auto build_end_time = std::chrono::system_clock::now(); @@ -1801,6 +1808,7 @@ void test_clone( out.staging = graph.set_output_tensor(out.value); graph.prepare(); + graph.prepack(); fill_vtensor(graph, a, 0.0f, /*iota = */ true); @@ -1885,6 +1893,7 @@ TEST(VulkanComputeGraphTest, test_etvk_copy_offset_node) { out.staging = graph.set_output_tensor(out.value); graph.prepare(); + graph.prepack(); fill_vtensor(graph, a, 0.0f, /*iota = */ true); @@ -1948,6 +1957,7 @@ TEST(VulkanComputeGraphTest, DISABLED_test_etvk_copy_channel_offset_node) { out.staging = graph.set_output_tensor(out.value); graph.prepare(); + graph.prepack(); fill_vtensor(graph, a, 0.0f, true); @@ -2038,6 +2048,7 @@ TEST( out.staging = graph.set_output_tensor(out.value); graph.prepare(); + graph.prepack(); float a_value = 1.0f; float b_value = 2.0f; @@ -2150,6 +2161,7 @@ TEST(VulkanComputeGraphTest, test_etvk_copy_offset_int_node) { out.staging = graph.set_output_tensor(out.value); graph.prepare(); + graph.prepack(); fill_vtensor(graph, a, 0, /*iota = */ true); @@ -2213,6 +2225,7 @@ TEST(VulkanComputeGraphTest, DISABLED_test_etvk_copy_channel_offset_int_node) { out.staging = graph.set_output_tensor(out.value); graph.prepare(); + graph.prepack(); fill_vtensor(graph, a, 0.0f, true); @@ -2272,6 +2285,7 @@ TEST(VulkanComputeGraphTest, test_view_change_packing) { out.staging = graph.set_output_tensor(out.value); graph.prepare(); + graph.prepack(); fill_vtensor(graph, in, 0.0, true); @@ -2430,6 +2444,7 @@ void compute_graph_round_trip_test( ValueRef r_staging_out = graph.set_output_tensor(r_tensor); graph.prepare(); + graph.prepack(); std::vector data_in(graph.numel_of(r_tensor)); for (int i = 0; i < data_in.size(); i++) { @@ -2620,7 +2635,6 @@ void test_mm( B, M, K, N, dtype, storage_type, memory_layout, mat2_data, prepack); graph.prepare(); - graph.prepack(); for (int i = 1; i < 4; i++) { @@ -2700,7 +2714,6 @@ void test_mm_with_resize_reencode( B, M, K, N, dtype, storage_type, memory_layout, mat2_data, false); graph.prepare(); - graph.prepack(); for (int i = 1; i < 4; i++) { @@ -3122,7 +3135,6 @@ void test_dynamic_dispatch(int M, int N) { ComputeGraph graph = build_dynamic_dispatch_test_graph(M, N); graph.prepare(); - graph.prepack(); for (int i = 1; i < 4; i++) {