From bc0facdcb3826ee22fc64c5e1f2776494007ea90 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 2 Jan 2025 13:05:26 -0800 Subject: [PATCH] [ET-VK][ez] Fix undefined behaviour in ambiguous `ParamsBuffer` constructor ## Context I discovered this bug when trying to execute the `vulkan_compute_api_test` binary on Windows. Almost all the tests were failing, with compute shaders producing incorrect results. After bisecting the change, it turns out the culprit is https://github.com/pytorch/executorch/pull/7015. The diff introduced an alternative templated constructor for `ParamsBuffer` which would initialize an empty UBO with a specified size instead of wrapping a pre-existing object. The issue is that these constructors are ambiguous because they both are template constructors and both only accept one argument. Therefore, the original constructor would be called when certain callsites intended to call the new constructor. This results in a UBO being created with an incorrect size, and resulted in the tensor's metadata being passed incorrectly into a compute shader. To fix, I added a dummy argument into the new constructor for disambiguation purposes. I also changed it so that it's not templated, since there's no reason for it to be templated. Differential Revision: [D67770791](https://our.internmc.facebook.com/intern/diff/D67770791/) [ghstack-poisoned] --- backends/vulkan/runtime/api/containers/ParamsBuffer.h | 5 +++-- backends/vulkan/runtime/api/containers/Tensor.cpp | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/backends/vulkan/runtime/api/containers/ParamsBuffer.h b/backends/vulkan/runtime/api/containers/ParamsBuffer.h index fe157c5e014..ecc07892cf7 100644 --- a/backends/vulkan/runtime/api/containers/ParamsBuffer.h +++ b/backends/vulkan/runtime/api/containers/ParamsBuffer.h @@ -31,8 +31,9 @@ class ParamsBuffer final { vulkan_buffer_( context_p_->adapter_ptr()->vma().create_params_buffer(block)) {} - template - ParamsBuffer(Context* context_p, const VkDeviceSize nbytes) + // The last bool argument, though unused, is required to disambiguate this + // constructor from the one above. + ParamsBuffer(Context* context_p, const VkDeviceSize nbytes, const bool unused) : context_p_(context_p), vulkan_buffer_( context_p_->adapter_ptr()->vma().create_uniform_buffer(nbytes)) {} diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index 21b0ee4b176..92e310d36de 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -659,7 +659,7 @@ utils::GPUMemoryLayout vTensor::estimate_memory_layout() const { const vkapi::BufferBindInfo vTensor::sizes_ubo() { if (!uniforms_.buffer()) { - uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize); + uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize, true); } if (sizes_uniform_offset_ == kUniformOffsetUnset) { VK_CHECK_COND( @@ -674,7 +674,7 @@ const vkapi::BufferBindInfo vTensor::sizes_ubo() { const vkapi::BufferBindInfo vTensor::strides_ubo() { if (!uniforms_.buffer()) { - uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize); + uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize, true); } if (unsqueezed_strides_offset_ == kUniformOffsetUnset) { VK_CHECK_COND( @@ -691,7 +691,7 @@ const vkapi::BufferBindInfo vTensor::strides_ubo() { const vkapi::BufferBindInfo vTensor::logical_limits_ubo() { if (!uniforms_.buffer()) { - uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize); + uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize, true); } if (logical_limits_uniform_offset_ == kUniformOffsetUnset) { VK_CHECK_COND( @@ -707,7 +707,7 @@ const vkapi::BufferBindInfo vTensor::logical_limits_ubo() { const vkapi::BufferBindInfo vTensor::numel_ubo() { if (!uniforms_.buffer()) { - uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize); + uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize, true); } if (numel_uniform_offset_ == kUniformOffsetUnset) { VK_CHECK_COND(