Commit 41a7f72

Update on "[ET-VK] Adding convenience functions in Compute graph to get PushConstantDataInfo for various attributes of a tensor."

This diff adds convenience functions in the Compute graph to get PushConstantDataInfo for various attributes of a tensor.

Differential Revision: [D66853502](https://our.internmc.facebook.com/intern/diff/D66853502/)

[ghstack-poisoned]
2 parents d7fa172 + 9bd96e9 commit 41a7f72
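
For context: the ComputeGraph convenience functions the title refers to are not among the file diffs excerpted below. A minimal sketch of the shape such accessors could take — the method names and signatures here are assumptions, not the committed API:

```cpp
// Hypothetical sketch only — the ComputeGraph changes themselves are not
// among the file diffs excerpted below. The idea: one accessor per tensor
// attribute, each returning ready-to-bind push constant data for the tensor
// behind a ValueRef. All names and signatures here are assumptions.
class ComputeGraph {
 public:
  vkapi::PushConstantDataInfo sizes_pc_of(const ValueRef idx);
  vkapi::PushConstantDataInfo strides_pc_of(const ValueRef idx);
  vkapi::PushConstantDataInfo logical_limits_pc_of(const ValueRef idx);
  vkapi::PushConstantDataInfo numel_pc_of(const ValueRef idx);
};
```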

File tree: 9 files changed, +67 −183 lines


backends/vulkan/runtime/api/Context.cpp

Lines changed: 1 addition & 12 deletions

```diff
@@ -119,9 +119,7 @@ void Context::register_shader_dispatch(
     const vkapi::DescriptorSet& descriptors,
     vkapi::PipelineBarrier& pipeline_barrier,
     const vkapi::ShaderInfo& shader_descriptor,
-    const utils::uvec3& global_workgroup_size,
-    const void* push_constants_data,
-    const uint32_t push_constants_size) {
+    const utils::uvec3& global_workgroup_size) {
   // Adjust the global workgroup size based on the output tile size
   uint32_t global_wg_w = utils::div_up(
       global_workgroup_size[0u], shader_descriptor.out_tile_size[0u]);
@@ -147,15 +145,6 @@ void Context::register_shader_dispatch(
   cmd_.bind_descriptors(descriptors.get_bind_handle());
   cmd_.insert_barrier(pipeline_barrier);
 
-  if (push_constants_size > 0 && push_constants_data != nullptr) {
-    const VkDescriptorSetLayout shader_layout =
-        shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
-    const VkPipelineLayout pipeline_layout =
-        pipeline_layout_cache().retrieve(shader_layout);
-    cmd_.set_push_constants(
-        pipeline_layout, push_constants_data, push_constants_size);
-  }
-
   cmd_.dispatch(effective_global_wg);
 }
```
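
With push constants removed from register_shader_dispatch(), a caller that still needs them would bind them on the command buffer itself before registering the dispatch. A minimal sketch, reusing the two cache lookups from the deleted block; the cmd() accessor is a hypothetical stand-in for access to the active command buffer (inside Context the private member cmd_ is used directly):

```cpp
#include <executorch/backends/vulkan/runtime/api/Context.h>

namespace vkcompute {

// Sketch: the same two-step layout retrieval the deleted block performed,
// now done by the caller. cmd() is a *hypothetical* accessor for the active
// command buffer and is an assumption, not a confirmed Context API.
void bind_push_constants(
    api::Context& ctx,
    const vkapi::ShaderInfo& shader_descriptor,
    const void* push_constants_data,
    const uint32_t push_constants_size) {
  if (push_constants_size > 0 && push_constants_data != nullptr) {
    const VkDescriptorSetLayout shader_layout =
        ctx.shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
    const VkPipelineLayout pipeline_layout =
        ctx.pipeline_layout_cache().retrieve(shader_layout);
    ctx.cmd().set_push_constants(
        pipeline_layout, push_constants_data, push_constants_size);
  }
}

} // namespace vkcompute
```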

backends/vulkan/runtime/api/Context.h

Lines changed: 1 addition & 3 deletions

```diff
@@ -200,9 +200,7 @@ class Context final {
       const vkapi::DescriptorSet&,
       vkapi::PipelineBarrier&,
       const vkapi::ShaderInfo&,
-      const utils::uvec3&,
-      const void* = nullptr,
-      const uint32_t = 0);
+      const utils::uvec3&);
 
   void register_blit(
       vkapi::PipelineBarrier&,
```

backends/vulkan/runtime/api/containers/StagingBuffer.h

Lines changed: 11 additions & 7 deletions

```diff
@@ -23,6 +23,8 @@ class StagingBuffer final {
  private:
   Context* context_p_;
   vkapi::ScalarType dtype_;
+  size_t numel_;
+  size_t nbytes_;
   vkapi::VulkanBuffer vulkan_buffer_;
 
   void* mapped_data_;
@@ -34,8 +36,10 @@ class StagingBuffer final {
       const size_t numel)
       : context_p_(context_p),
         dtype_(dtype),
-        vulkan_buffer_(context_p_->adapter_ptr()->vma().create_staging_buffer(
-            element_size(dtype_) * numel)),
+        numel_(numel),
+        nbytes_(element_size(dtype_) * numel_),
+        vulkan_buffer_(
+            context_p_->adapter_ptr()->vma().create_staging_buffer(nbytes_)),
         mapped_data_(nullptr) {}
 
   StagingBuffer(const StagingBuffer&) = delete;
@@ -64,15 +68,15 @@ class StagingBuffer final {
   }
 
   inline size_t numel() {
-    return nbytes() / element_size(dtype_);
+    return numel_;
   }
 
   inline size_t nbytes() {
-    return vulkan_buffer_.mem_size();
+    return nbytes_;
   }
 
   inline void copy_from(const void* src, const size_t nbytes) {
-    VK_CHECK_COND(nbytes <= this->nbytes());
+    VK_CHECK_COND(nbytes <= nbytes_);
     memcpy(data(), src, nbytes);
     vmaFlushAllocation(
         vulkan_buffer_.vma_allocator(),
@@ -82,7 +86,7 @@ class StagingBuffer final {
   }
 
   inline void copy_to(void* dst, const size_t nbytes) {
-    VK_CHECK_COND(nbytes <= this->nbytes());
+    VK_CHECK_COND(nbytes <= nbytes_);
     vmaInvalidateAllocation(
         vulkan_buffer_.vma_allocator(),
         vulkan_buffer_.allocation(),
@@ -92,7 +96,7 @@ class StagingBuffer final {
   }
 
   inline void set_staging_zeros() {
-    memset(data(), 0, nbytes());
+    memset(data(), 0, nbytes_);
   }
 };
```
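
A plausible motivation for caching numel_ and nbytes_ rather than deriving them from vulkan_buffer_.mem_size(): the allocator may round an allocation up, so mem_size() can exceed the requested element_size(dtype) * numel, and the old division-based numel() would then over-report. An illustrative sketch under that assumption, assuming StagingBuffer lives in vkcompute::api and ctx points to a live context:

```cpp
#include <executorch/backends/vulkan/runtime/api/containers/StagingBuffer.h>

#include <cassert>

// Illustration only. Requests 10 * element_size(kFloat) = 40 bytes; the
// backing allocation may be rounded up by the allocator, so a numel derived
// from mem_size() could exceed 10. The cached fields are exact by
// construction.
void check_staging_sizes(vkcompute::api::Context* ctx) {
  using namespace vkcompute;
  api::StagingBuffer staging(ctx, vkapi::kFloat, /*numel=*/10);
  assert(staging.numel() == 10);                   // cached, not derived
  assert(staging.nbytes() == 10 * sizeof(float));  // requested size
}
```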

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 27 additions & 69 deletions

```diff
@@ -7,7 +7,6 @@
  */
 
 #include <executorch/backends/vulkan/runtime/api/containers/Tensor.h>
-#include <cstring>
 
 namespace vkcompute {
 namespace api {
@@ -447,10 +446,11 @@ vTensor::vTensor(
       dim_order_(calculate_dim_order(sizes_.size(), packed_dim_)),
       axis_map_(default_axis_map()),
       strides_(calculate_strides(sizes, dim_order_)),
+      numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, packed_dim_)},
-      unsqueezed_strides_{
-          unsqueeze_strides(strides_, utils::multiply_integers(sizes_))},
+      unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
       padded_numel_(utils::multiply_integers(padded_sizes_)),
+      logical_limits_{{0, 0, 0}},
       uniforms_(),
       // Utility Uniform Buffers that can be passed to shaders as arguments
       uniforms_size_(0),
@@ -467,11 +467,6 @@ vTensor::vTensor(
           padded_sizes_,
           dtype_,
           allocate_memory) {
-  uniform_data_ = std::make_shared<UniformData>(UniformData{
-      sizes_,
-      unsqueezed_strides_,
-      {{0, 0, 0}},
-      static_cast<size_t>(utils::multiply_integers(sizes_))});
   VK_CHECK_COND(
       dim_order_is_valid(dim_order_), "computed dim order is invalid");
 
@@ -499,9 +494,11 @@ vTensor::vTensor(
       dim_order_(),
       axis_map_(default_axis_map()),
       strides_(),
+      numel_(utils::multiply_integers(sizes_)),
       padded_sizes_(calculate_padded_sizes(sizes_, packed_dim_)),
       unsqueezed_strides_(),
       padded_numel_(utils::multiply_integers(padded_sizes_)),
+      logical_limits_(),
       uniforms_(),
       // Utility Uniform Buffers that can be passed to shaders as arguments
       uniforms_size_(0),
@@ -511,11 +508,6 @@ vTensor::vTensor(
       logical_limits_uniform_offset_(kUniformOffsetUnset),
       // Construct Tensor storage
       storage_(context, image) {
-  uniform_data_ = std::make_shared<UniformData>(UniformData{
-      sizes_,
-      {0, 0, 0, 0},
-      {{0, 0, 0}},
-      static_cast<size_t>(utils::multiply_integers(sizes_))});
   set_logical_limits(storage_.image_extents_);
 }
 
@@ -527,11 +519,13 @@ vTensor::vTensor(vTensor& other)
       dim_order_(other.dim_order_.begin(), other.dim_order_.end()),
       axis_map_(other.axis_map_.begin(), other.axis_map_.end()),
       strides_(other.strides_.begin(), other.strides_.end()),
+      numel_(other.numel_),
       padded_sizes_{other.padded_sizes_.begin(), other.padded_sizes_.end()},
       unsqueezed_strides_{
           other.unsqueezed_strides_.begin(),
           other.unsqueezed_strides_.end()},
       padded_numel_(other.padded_numel_),
+      logical_limits_{other.logical_limits_},
       uniforms_(),
       // Empty initialize Utility Uniform Buffers
       uniforms_size_(0),
@@ -540,9 +534,7 @@ vTensor::vTensor(vTensor& other)
       numel_uniform_offset_(kUniformOffsetUnset),
       logical_limits_uniform_offset_(kUniformOffsetUnset),
       // Copy Tensor storage
-      storage_(other.storage_) {
-  uniform_data_ = std::make_shared<UniformData>(*other.get_uniform_data());
-}
+      storage_(other.storage_) {}
 
 vTensor::vTensor(
     vTensor& other,
@@ -556,10 +548,11 @@ vTensor::vTensor(
       dim_order_(dim_order.begin(), dim_order.end()),
       axis_map_(default_axis_map()),
       strides_(calculate_strides(sizes_, dim_order_)),
+      numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, packed_dim_)},
-      unsqueezed_strides_{
-          unsqueeze_strides(strides_, utils::multiply_integers(sizes_))},
+      unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
       padded_numel_(utils::multiply_integers(padded_sizes_)),
+      logical_limits_(other.logical_limits_),
       uniforms_(),
       // Empty initialize Utility Uniform Buffers
       uniforms_size_(0),
@@ -569,45 +562,14 @@ vTensor::vTensor(
       logical_limits_uniform_offset_(kUniformOffsetUnset),
       // Copy Tensor storage
       storage_(other.storage_, vkapi::element_size(dtype_) * offset_numel) {
-  uniform_data_ = std::make_shared<UniformData>(UniformData{
-      sizes_,
-      unsqueezed_strides_,
-      {other.logical_limits()},
-      static_cast<size_t>(utils::multiply_integers(sizes_))});
-
   VK_CHECK_COND(
       dim_order_is_valid(dim_order_), "new dim order provided is invalid");
   VK_CHECK_COND(
-      offset_numel + numel() <= other.numel(),
+      offset_numel + numel_ <= other.numel(),
       "Tensor alias cannot access more elements than available in the original"
       "tensor");
 }
 
-uint32_t vTensor::UniformData::write_attribute(
-    void* dst,
-    const uint32_t dst_offset,
-    const uint32_t max_dst_size,
-    const Attribute attr) {
-#define WRITE_ATTRIBUTE_CASE(enum_name, member_name)                       \
-  case vTensor::Attribute::enum_name: {                                    \
-    VK_CHECK_COND(                                                         \
-        (dst_offset + sizeof(member_name)) <= max_dst_size,                \
-        "Attempting to write tensor attribute outside data boundary.");    \
-    memcpy((uint8_t*)dst + dst_offset, &member_name, sizeof(member_name)); \
-    return sizeof(member_name);                                            \
-  }
-  switch (attr) {
-    WRITE_ATTRIBUTE_CASE(SIZES, sizes_v);
-    WRITE_ATTRIBUTE_CASE(STRIDES, strides_v);
-    WRITE_ATTRIBUTE_CASE(LOGICAL_LIMITS, logical_limits);
-    WRITE_ATTRIBUTE_CASE(NUMEL, numel);
-    default:
-      VK_THROW("Invalid Attribute");
-  }
-#undef WRITE_ATTRIBUTE_CASE
-  return 0;
-}
-
 vkapi::VulkanImage& vTensor::image(
     vkapi::PipelineBarrier& pipeline_barrier,
     const vkapi::PipelineStageFlags stage) & {
@@ -639,9 +601,9 @@ vkapi::VulkanBuffer& vTensor::buffer(
 }
 
 void vTensor::set_logical_limits(const utils::uvec3& image_extents) {
-  uniform_data_->logical_limits.limits[0] = image_extents[axis_map_.at(0)];
-  uniform_data_->logical_limits.limits[1] = image_extents[axis_map_.at(1)];
-  uniform_data_->logical_limits.limits[2] = image_extents[axis_map_.at(2)];
+  logical_limits_.limits[0] = image_extents[axis_map_.at(0)];
+  logical_limits_.limits[1] = image_extents[axis_map_.at(1)];
+  logical_limits_.limits[2] = image_extents[axis_map_.at(2)];
 }
 
 utils::GPUMemoryLayout vTensor::estimate_memory_layout() const {
@@ -699,7 +661,7 @@ const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
         "Uniform data allocation has exceeded Tensor uniform buffer size");
     logical_limits_uniform_offset_ = uniforms_size_;
     uniforms_size_ += kSizePerUniform;
-    uniforms_.update(logical_limits(), logical_limits_uniform_offset_);
+    uniforms_.update(logical_limits_, logical_limits_uniform_offset_);
   }
   return vkapi::BufferBindInfo(
       uniforms_.buffer(), logical_limits_uniform_offset_);
@@ -715,7 +677,7 @@ const vkapi::BufferBindInfo vTensor::numel_ubo() {
        "Uniform data allocation has exceeded Tensor uniform buffer size");
     numel_uniform_offset_ = uniforms_size_;
     uniforms_size_ += kSizePerUniform;
-    uniforms_.update(numel(), numel_uniform_offset_);
+    uniforms_.update(numel_, numel_uniform_offset_);
   }
   return vkapi::BufferBindInfo(uniforms_.buffer(), numel_uniform_offset_);
 }
@@ -725,10 +687,10 @@ size_t vTensor::staging_buffer_numel() const {
   const bool int8_supported =
       storage_.context_->adapter_ptr()->has_full_int8_buffers_support();
   if (is_int8 && !int8_supported) {
-    return utils::align_up_4(numel());
+    return utils::align_up_4(numel_);
   }
   if (storage_type() == utils::kBuffer) {
-    return numel();
+    return numel_;
   }
   return padded_numel_;
 }
@@ -758,32 +720,30 @@ void vTensor::bind_allocation(const vkapi::Allocation& allocation) {
 
 void vTensor::update_metadata() {
   strides_ = calculate_strides(sizes_, dim_order_);
-  uniform_data_->numel = utils::multiply_integers(sizes_);
+  numel_ = utils::multiply_integers(sizes_);
 
   padded_sizes_ = calculate_padded_sizes(sizes_, packed_dim_);
-  unsqueezed_strides_ = unsqueeze_strides(strides_, numel());
+  unsqueezed_strides_ = unsqueeze_strides(strides_, numel_);
   padded_numel_ = utils::multiply_integers(padded_sizes_);
 
-  // Update uniform data if it has been modified
-  uniform_data_->sizes_v = utils::make_whcn_ivec4(sizes_);
-  uniform_data_->strides_v = utils::make_whcn_ivec4(unsqueezed_strides_);
-
   // Calculate the image extents that would have been used to allocate a
   // texture with the current sizes, and use that to set the logical limits.
   set_logical_limits(
       calculate_image_extents(padded_sizes_, axis_map_, packed_dim_));
 
   if (sizes_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
+    uniforms_.update(utils::make_whcn_ivec4(sizes_), sizes_uniform_offset_);
   }
   if (unsqueezed_strides_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->strides_v, unsqueezed_strides_offset_);
+    uniforms_.update(
+        utils::make_whcn_ivec4(unsqueezed_strides_),
+        unsqueezed_strides_offset_);
   }
   if (numel_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(numel(), numel_uniform_offset_);
+    uniforms_.update(numel_, numel_uniform_offset_);
   }
   if (logical_limits_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(logical_limits(), logical_limits_uniform_offset_);
+    uniforms_.update(logical_limits_, logical_limits_uniform_offset_);
   }
 }
 
@@ -836,8 +796,6 @@ void vTensor::virtual_clone(const vTensor& other) {
  dim_order_ = other.dim_order_;
   axis_map_ = other.axis_map_;
   packed_dim_ = other.packed_dim_;
-
-  *uniform_data_ = *other.get_uniform_data();
 }
 
 void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
```
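```

The through-line of this file's diff: numel and the logical limits move out of the shared UniformData object (deleted above along with UniformData::write_attribute) and become plain vTensor members that update_metadata() keeps in sync. The matching Tensor.h declarations are not shown in this excerpt; a sketch of what they presumably look like, with the member types inferred from usage (logical_limits_.limits[i] in set_logical_limits()):

```cpp
// Presumed member declarations in Tensor.h (not part of this excerpt); the
// TextureLimits type name is inferred from logical_limits_.limits[i] above
// and is an assumption.
class vTensor {
  // ...
  size_t numel_;                  // multiply_integers(sizes_); refreshed by
                                  // update_metadata() on resize
  TextureLimits logical_limits_;  // image extents remapped through axis_map_
  // ...
};
```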
