diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 75726ae0892..e84cb884e83 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -140,6 +140,14 @@ GraphConfig get_graph_config(ArrayRef& compile_specs) { config.set_memory_layout_override(memory_layout); } + if (strcmp(spec.key, "require_dynamic_shapes") == 0) { + ET_CHECK_MSG(value_size == sizeof(uint8_t), "Unexpected value size!"); + bool value = getBool(value_data); + + if (value) { + config.expect_dynamic_shapes = true; + } + } } #ifdef ET_EVENT_TRACER_ENABLED config.enable_querypool = true; @@ -500,9 +508,12 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { compute_graph->encode_prepack(); compute_graph->prepack(); - // TODO(ssjia): remove this once we can batch compile compute pipelines - // during prepare(). - compute_graph->encode_execute(); + // If dynamic shapes are not expected, then the command buffer only needs to + // be encoded once. Otherwise, wait until the first inference to encode + // the command buffer, when actual input shapes are known. + if (!compute_graph->graphconfig().expect_dynamic_shapes) { + compute_graph->encode_execute(); + } return Error::Ok; } @@ -574,7 +585,9 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { // constants are updated and DynamicDispatchNode can update the compute // shader, global workgroup size, and local workgroup size to perform the // model inference. 
- if (should_propagate_resize) { + if (should_propagate_resize || + (compute_graph->graphconfig().expect_dynamic_shapes && + compute_graph->execute_count() == 0u)) { compute_graph->propagate_resize(); } diff --git a/backends/vulkan/runtime/VulkanDelegateHeader.cpp b/backends/vulkan/runtime/VulkanDelegateHeader.cpp index 81fd0bbc953..2a235144342 100644 --- a/backends/vulkan/runtime/VulkanDelegateHeader.cpp +++ b/backends/vulkan/runtime/VulkanDelegateHeader.cpp @@ -60,6 +60,10 @@ uint32_t getUInt16LE(const uint8_t* data) { return (uint32_t)data[0] | ((uint32_t)data[1] << 8); } +bool getBool(const uint8_t* data) { + return data[0] != 0; +} + bool VulkanDelegateHeader::is_valid() const { if (header_size < kExpectedSize) { return false; diff --git a/backends/vulkan/runtime/VulkanDelegateHeader.h b/backends/vulkan/runtime/VulkanDelegateHeader.h index 0fc163bbe3c..722f01cbb75 100644 --- a/backends/vulkan/runtime/VulkanDelegateHeader.h +++ b/backends/vulkan/runtime/VulkanDelegateHeader.h @@ -19,6 +19,9 @@ uint64_t getUInt64LE(const uint8_t* data); uint32_t getUInt32LE(const uint8_t* data); uint32_t getUInt16LE(const uint8_t* data); +// Bool is serialized as a single byte +bool getBool(const uint8_t* data); + struct VulkanDelegateHeader { bool is_valid() const; diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 5dc26286682..0fbeca36979 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -761,7 +761,10 @@ void ComputeGraph::propagate_resize() { for (std::unique_ptr& node : execute_nodes_) { node->trigger_resize(this); } - encode_execute(); + // Only re-encode on resize if dynamic shapes are expected + if (config_.expect_dynamic_shapes) { + encode_execute(); + } } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/GraphConfig.cpp b/backends/vulkan/runtime/graph/GraphConfig.cpp index 887b46c002a..20606be6a96 100644 --- 
a/backends/vulkan/runtime/graph/GraphConfig.cpp +++ b/backends/vulkan/runtime/graph/GraphConfig.cpp @@ -63,6 +63,8 @@ GraphConfig::GraphConfig() { enable_local_wg_size_override = false; local_wg_size_override = {}; + + expect_dynamic_shapes = false; } void GraphConfig::set_storage_type_override(utils::StorageType storage_type) { diff --git a/backends/vulkan/runtime/graph/GraphConfig.h b/backends/vulkan/runtime/graph/GraphConfig.h index aa7df2cb413..df2d6d6f2e1 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.h +++ b/backends/vulkan/runtime/graph/GraphConfig.h @@ -33,6 +33,9 @@ struct GraphConfig final { bool enable_local_wg_size_override; utils::uvec3 local_wg_size_override; + // Whether or not the ComputeGraph should expect input shapes to be dynamic + bool expect_dynamic_shapes; + // Generate a default graph config with pre-configured settings explicit GraphConfig(); diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 447e5d039f4..4d9e3fe61ee 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -43,6 +43,9 @@ def lower_module( model: torch.nn.Module, sample_inputs: Tuple[torch.Tensor], dynamic_shapes=None ) -> EdgeProgramManager: compile_options = {} + if dynamic_shapes is not None: + compile_options["require_dynamic_shapes"] = True + edge_compile_config = EdgeCompileConfig( _skip_dim_order=False, # TODO(T182928844): Delegate dim order op to backend. ) @@ -70,6 +73,9 @@ def quantize_and_lower_module( dynamic_shapes=None, ) -> EdgeProgramManager: compile_options = {} + if dynamic_shapes is not None: + compile_options["require_dynamic_shapes"] = True + edge_compile_config = EdgeCompileConfig( _skip_dim_order=False, # TODO(T182928844): Delegate dim order op to backend. 
) diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index 3497aeb5705..3f5dba9e277 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -512,6 +512,7 @@ vkcompute::ComputeGraph build_mm_graph( const bool prepack_mat2) { using namespace vkcompute; GraphConfig config; + config.expect_dynamic_shapes = true; ComputeGraph graph(config); std::vector mat1_size = {M, K}; diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index f89d4dca705..c4ccc860bc2 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -2937,6 +2937,7 @@ void test_transpose_view_mm( const int N, utils::StorageType storage_type) { GraphConfig config; + config.expect_dynamic_shapes = true; config.set_storage_type_override(storage_type); ComputeGraph graph(config); @@ -2993,7 +2994,6 @@ void test_transpose_view_mm( graph.prepare(); graph.encode_prepack(); graph.prepack(); - graph.encode_execute(); for (int i = 1; i < 4; i++) { float val_mat1 = i;