Skip to content

Commit 377c463

Browse files
committed
[ET-VK][ez] Allow high dimensional tensors (for buffer storage)
ghstack-source-id: efd079a Pull Request resolved: #13596
1 parent e63a68d commit 377c463

File tree

4 files changed

+68
-41
lines changed

4 files changed

+68
-41
lines changed

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,14 @@ utils::uvec3 calculate_image_extents(
189189
const std::vector<int64_t>& padded_sizes,
190190
const std::vector<int64_t>& axis_map,
191191
const int32_t packed_dim) {
192-
VK_CHECK_COND(padded_sizes.size() == 4);
193-
VK_CHECK_COND(axis_map.size() == 4);
194-
195192
utils::uvec3 extents({1, 1, 1});
193+
194+
// For high dimensional tensors, buffer storage must be used. No need to
195+
// compute image extents in this case.
196+
if (padded_sizes.size() > 4) {
197+
return extents;
198+
}
199+
196200
// First three elements of axis_map indicate which (X,Y,Z) image axis the
197201
// width, height, and channels dim of the tensor maps to.
198202
for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
@@ -576,12 +580,15 @@ vTensor::vTensor(
576580
sizes,
577581
dtype_,
578582
allocate_memory)) {
579-
uniform_data_ = std::make_shared<UniformData>(UniformData{
580-
numel_,
581-
sizes_,
582-
dim_order_,
583-
strides_,
584-
calculate_logical_limits(storage_->image_extents_, axis_map_)});
583+
// uniform_data_ only valid for low dim tensors
584+
if (sizes.size() <= 4) {
585+
uniform_data_ = std::make_shared<UniformData>(UniformData{
586+
numel_,
587+
sizes_,
588+
dim_order_,
589+
strides_,
590+
calculate_logical_limits(storage_->image_extents_, axis_map_)});
591+
}
585592

586593
VK_CHECK_COND(
587594
dim_order_is_valid(dim_order_), "computed dim order is invalid");
@@ -813,24 +820,29 @@ size_t vTensor::get_max_ubo_nbytes(const size_t nbytes_per_ubo) const {
813820
}
814821

815822
const vkapi::BufferBindInfo vTensor::sizes_ubo() {
823+
VK_CHECK_COND(sizes_.size() <= 4);
816824
return metadata_ubo_impl(&sizes_uniform_offset_, uniform_data_->sizes_v);
817825
}
818826

819827
const vkapi::BufferBindInfo vTensor::dim_order_ubo() {
828+
VK_CHECK_COND(sizes_.size() <= 4);
820829
return metadata_ubo_impl(
821830
&dim_order_uniform_offset_, uniform_data_->dim_order_v);
822831
}
823832

824833
const vkapi::BufferBindInfo vTensor::strides_ubo() {
834+
VK_CHECK_COND(sizes_.size() <= 4);
825835
return metadata_ubo_impl(&strides_uniform_offset, uniform_data_->strides_v);
826836
}
827837

828838
const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
839+
VK_CHECK_COND(sizes_.size() <= 4);
829840
return metadata_ubo_impl(
830841
&logical_limits_uniform_offset_, uniform_data_->logical_limits);
831842
}
832843

833844
const vkapi::BufferBindInfo vTensor::numel_ubo() {
845+
VK_CHECK_COND(sizes_.size() <= 4);
834846
return metadata_ubo_impl(&numel_uniform_offset_, uniform_data_->numel);
835847
}
836848

@@ -893,31 +905,33 @@ void vTensor::update_metadata() {
893905
strides_ = calculate_strides(sizes_, dim_order_);
894906

895907
// Update uniform data if it has been modified
896-
uniform_data_->numel = utils::safe_downcast<int32_t>(numel_);
897-
uniform_data_->sizes_v =
898-
flip_and_unsqueeze_ivec4(sizes_, kTensorSizes, numel_);
899-
uniform_data_->dim_order_v =
900-
flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_);
901-
uniform_data_->strides_v =
902-
flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_);
903-
uniform_data_->logical_limits.limits =
904-
calculate_logical_limits(sizes_, axis_map_, packed_dim_);
905-
906-
if (sizes_uniform_offset_ != kUniformOffsetUnset) {
907-
uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
908-
}
909-
if (dim_order_uniform_offset_ != kUniformOffsetUnset) {
910-
uniforms_.update(uniform_data_->dim_order_v, dim_order_uniform_offset_);
911-
}
912-
if (strides_uniform_offset != kUniformOffsetUnset) {
913-
uniforms_.update(uniform_data_->strides_v, strides_uniform_offset);
914-
}
915-
if (numel_uniform_offset_ != kUniformOffsetUnset) {
916-
uniforms_.update(numel_, numel_uniform_offset_);
917-
}
918-
if (logical_limits_uniform_offset_ != kUniformOffsetUnset) {
919-
uniforms_.update(
920-
uniform_data_->logical_limits.limits, logical_limits_uniform_offset_);
908+
if (sizes_.size() <= 4) {
909+
uniform_data_->numel = utils::safe_downcast<int32_t>(numel_);
910+
uniform_data_->sizes_v =
911+
flip_and_unsqueeze_ivec4(sizes_, kTensorSizes, numel_);
912+
uniform_data_->dim_order_v =
913+
flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_);
914+
uniform_data_->strides_v =
915+
flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_);
916+
uniform_data_->logical_limits.limits =
917+
calculate_logical_limits(sizes_, axis_map_, packed_dim_);
918+
919+
if (sizes_uniform_offset_ != kUniformOffsetUnset) {
920+
uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
921+
}
922+
if (dim_order_uniform_offset_ != kUniformOffsetUnset) {
923+
uniforms_.update(uniform_data_->dim_order_v, dim_order_uniform_offset_);
924+
}
925+
if (strides_uniform_offset != kUniformOffsetUnset) {
926+
uniforms_.update(uniform_data_->strides_v, strides_uniform_offset);
927+
}
928+
if (numel_uniform_offset_ != kUniformOffsetUnset) {
929+
uniforms_.update(numel_, numel_uniform_offset_);
930+
}
931+
if (logical_limits_uniform_offset_ != kUniformOffsetUnset) {
932+
uniforms_.update(
933+
uniform_data_->logical_limits.limits, logical_limits_uniform_offset_);
934+
}
921935
}
922936

923937
if (buffer_meta_.buffer()) {

backends/vulkan/runtime/api/containers/Tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,7 @@ class vTensor final {
676676
}
677677

678678
const std::shared_ptr<UniformData>& get_uniform_data() const {
679+
VK_CHECK_COND(sizes_.size() <= 4);
679680
return uniform_data_;
680681
}
681682
};

backends/vulkan/test/op_tests/cases.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,28 @@ def get_binary_elementwise_inputs():
5555
((3, 64, 1), (1, 64, 1)),
5656
]
5757
)
58-
test_suite.layouts = [
59-
"utils::kWidthPacked",
60-
"utils::kChannelsPacked",
61-
]
6258
test_suite.storage_types = [
6359
"utils::kBuffer",
6460
"utils::kTexture3D",
6561
]
6662

67-
return test_suite
63+
highdim_test_suite = VkTestSuite(
64+
[
65+
((4, 5, 8, 1, 2, 1), (4, 5, 8, 1, 1, 1)),
66+
]
67+
)
68+
highdim_test_suite.storage_types = [
69+
"utils::kBuffer",
70+
]
71+
highdim_test_suite.test_name_suffix = "highdim"
72+
73+
for suite in [test_suite, highdim_test_suite]:
74+
suite.layouts = [
75+
"utils::kWidthPacked",
76+
"utils::kChannelsPacked",
77+
]
78+
79+
return [test_suite, highdim_test_suite]
6880

6981

7082
# Eq requires a different test generator so it was split from the other test case.

backends/vulkan/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,9 +599,9 @@ def make_filtered_tensor_repset(
599599
if extents_are_valid(extents, texture_limits):
600600
valid_texture_layouts.add(memory_layout)
601601

602-
# High dimensional tensors are currently not supported
602+
# High dimensional tensors require buffer storage
603603
if len(tensor_val.shape) > 4:
604-
return NO_STORAGE
604+
return TensorRepSet(tensor_repset.valid_buffer_layouts, set())
605605

606606
# Bool tensors are currently not supported
607607
if tensor_val.dtype == torch.bool:

0 commit comments

Comments
 (0)