Skip to content

Commit eff18ef

Browse files
committed
Update on "[ET-VK] Replace Uniform buffers with push constants for copy op"
This diff replaces uniform buffers with push constants for the copy op in the Vulkan backend of ExecuTorch. The changes include updating the GLSL code to use push constants instead of uniform buffers, and updating the C++ code to pass the sizes to the shader as push constants. Differential Revision: [D66890851](https://our.internmc.facebook.com/intern/diff/D66890851/) [ghstack-poisoned]
2 parents 8c77b4d + 3511b07 commit eff18ef

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

62 files changed

+449
-112
lines changed

.ci/scripts/setup-vulkan-linux-deps.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ install_swiftshader() {
2727

2828
install_vulkan_sdk() {
2929
VULKAN_SDK_VERSION=$1
30-
_vulkan_sdk_url="https://sdk.lunarg.com/sdk/download/${VULKAN_SDK_VERSION}/linux/vulkansdk-linux-x86_64-${VULKAN_SDK_VERSION}.tar.gz"
30+
_vulkan_sdk_url="https://sdk.lunarg.com/sdk/download/${VULKAN_SDK_VERSION}/linux/vulkansdk-linux-x86_64-${VULKAN_SDK_VERSION}.tar.xz"
3131

3232
_vulkan_sdk_dir=/tmp/vulkansdk
3333
mkdir -p $_vulkan_sdk_dir
@@ -37,12 +37,12 @@ install_vulkan_sdk() {
3737
curl --silent --show-error --location --fail --retry 3 \
3838
--output "${_tmp_archive}" "${_vulkan_sdk_url}"
3939

40-
tar -C "${_vulkan_sdk_dir}" -xzf "${_tmp_archive}"
40+
tar -C "${_vulkan_sdk_dir}" -xJf "${_tmp_archive}"
4141

4242
export PATH="${PATH}:${_vulkan_sdk_dir}/${VULKAN_SDK_VERSION}/x86_64/bin/"
4343
}
4444

45-
VULKAN_SDK_VERSION="1.2.198.1"
45+
VULKAN_SDK_VERSION="1.3.296.0"
4646

4747
install_swiftshader
4848
install_vulkan_sdk "${VULKAN_SDK_VERSION}"

backends/arm/quantizer/quantization_annotation/generic_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
torch.ops.aten.tile.default,
5454
torch.ops.aten.flip.default,
5555
torch.ops.aten.cat.default,
56+
torch.ops.aten.concatenate.default,
5657
torch.ops.aten.stack.default,
5758
torch.ops.aten.chunk.default,
5859
torch.ops.aten.contiguous.default,

backends/arm/test/ops/test_to_copy.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def _test_to_copy_tosa_MI_pipeline(
5656
)
5757
.export()
5858
.dump_artifact()
59-
.check_count({"torch.ops.aten._to_copy.default": 1})
6059
.to_edge()
6160
.dump_artifact()
6261
.partition()

backends/arm/test/quantizer/test_generic_annotater.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,10 @@ def test_flip(self):
8686
self.check_annotation(
8787
SingleOpModel(torch.flip, (torch.randn(2, 4),), dims=(0, 1)),
8888
)
89+
90+
def test_concat(self):
91+
self.check_annotation(
92+
SingleOpModel(
93+
torch.concatenate, ((torch.randn(2, 3), torch.randn(2, 3)),), dim=0
94+
),
95+
)

backends/cadence/fusion_g3/operators/op_quantize.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ Tensor& quantize_per_tensor_out(
570570
err == torch::executor::Error::Ok,
571571
"Failed to resize out Tensor in quantize_per_tensor_out");
572572

573-
check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
573+
// check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
574574

575575
float scale_data = (float)scale;
576576
int zero_point_data = (int)zero_point;
@@ -696,7 +696,7 @@ Tensor& quantize_per_channel_out(
696696
zero_point.numel(),
697697
input.size(axis));
698698

699-
check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
699+
// check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
700700

701701
const double* scale_dt = scale.const_data_ptr<double>();
702702
const int64_t* zero_point_dt = zero_point.const_data_ptr<int64_t>();

backends/vulkan/runtime/api/Context.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ void Context::register_shader_dispatch(
119119
const vkapi::DescriptorSet& descriptors,
120120
vkapi::PipelineBarrier& pipeline_barrier,
121121
const vkapi::ShaderInfo& shader_descriptor,
122-
const utils::uvec3& global_workgroup_size) {
122+
const utils::uvec3& global_workgroup_size,
123+
const void* push_constants_data,
124+
const uint32_t push_constants_size) {
123125
// Adjust the global workgroup size based on the output tile size
124126
uint32_t global_wg_w = utils::div_up(
125127
global_workgroup_size[0u], shader_descriptor.out_tile_size[0u]);
@@ -145,6 +147,15 @@ void Context::register_shader_dispatch(
145147
cmd_.bind_descriptors(descriptors.get_bind_handle());
146148
cmd_.insert_barrier(pipeline_barrier);
147149

150+
if (push_constants_size > 0 && push_constants_data != nullptr) {
151+
const VkDescriptorSetLayout shader_layout =
152+
shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
153+
const VkPipelineLayout pipeline_layout =
154+
pipeline_layout_cache().retrieve(shader_layout);
155+
cmd_.set_push_constants(
156+
pipeline_layout, push_constants_data, push_constants_size);
157+
}
158+
148159
cmd_.dispatch(effective_global_wg);
149160
}
150161

backends/vulkan/runtime/api/Context.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,9 @@ class Context final {
200200
const vkapi::DescriptorSet&,
201201
vkapi::PipelineBarrier&,
202202
const vkapi::ShaderInfo&,
203-
const utils::uvec3&);
203+
const utils::uvec3&,
204+
const void* = nullptr,
205+
const uint32_t = 0);
204206

205207
void register_blit(
206208
vkapi::PipelineBarrier&,

backends/vulkan/runtime/api/containers/StagingBuffer.h

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ class StagingBuffer final {
2323
private:
2424
Context* context_p_;
2525
vkapi::ScalarType dtype_;
26-
size_t numel_;
27-
size_t nbytes_;
2826
vkapi::VulkanBuffer vulkan_buffer_;
2927

3028
void* mapped_data_;
@@ -36,10 +34,8 @@ class StagingBuffer final {
3634
const size_t numel)
3735
: context_p_(context_p),
3836
dtype_(dtype),
39-
numel_(numel),
40-
nbytes_(element_size(dtype_) * numel_),
41-
vulkan_buffer_(
42-
context_p_->adapter_ptr()->vma().create_staging_buffer(nbytes_)),
37+
vulkan_buffer_(context_p_->adapter_ptr()->vma().create_staging_buffer(
38+
element_size(dtype_) * numel)),
4339
mapped_data_(nullptr) {}
4440

4541
StagingBuffer(const StagingBuffer&) = delete;
@@ -68,15 +64,15 @@ class StagingBuffer final {
6864
}
6965

7066
inline size_t numel() {
71-
return numel_;
67+
return nbytes() / element_size(dtype_);
7268
}
7369

7470
inline size_t nbytes() {
75-
return nbytes_;
71+
return vulkan_buffer_.mem_size();
7672
}
7773

7874
inline void copy_from(const void* src, const size_t nbytes) {
79-
VK_CHECK_COND(nbytes <= nbytes_);
75+
VK_CHECK_COND(nbytes <= this->nbytes());
8076
memcpy(data(), src, nbytes);
8177
vmaFlushAllocation(
8278
vulkan_buffer_.vma_allocator(),
@@ -86,7 +82,7 @@ class StagingBuffer final {
8682
}
8783

8884
inline void copy_to(void* dst, const size_t nbytes) {
89-
VK_CHECK_COND(nbytes <= nbytes_);
85+
VK_CHECK_COND(nbytes <= this->nbytes());
9086
vmaInvalidateAllocation(
9187
vulkan_buffer_.vma_allocator(),
9288
vulkan_buffer_.allocation(),
@@ -96,7 +92,7 @@ class StagingBuffer final {
9692
}
9793

9894
inline void set_staging_zeros() {
99-
memset(data(), 0, nbytes_);
95+
memset(data(), 0, nbytes());
10096
}
10197
};
10298

0 commit comments

Comments
 (0)