@@ -189,10 +189,14 @@ utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
     const std::vector<int64_t>& axis_map,
     const int32_t packed_dim) {
-  VK_CHECK_COND(padded_sizes.size() == 4);
-  VK_CHECK_COND(axis_map.size() == 4);
-
   utils::uvec3 extents({1, 1, 1});
+
+  // For high dimensional tensors, buffer storage must be used. No need to
+  // compute image extents in this case.
+  if (padded_sizes.size() > 4) {
+    return extents;
+  }
+
   // First three elements of axis_map indicate which (X,Y,Z) image axis the
   // width, height, and channels dim of the tensor maps to.
   for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
@@ -576,12 +580,15 @@ vTensor::vTensor(
           sizes,
           dtype_,
           allocate_memory)) {
-  uniform_data_ = std::make_shared<UniformData>(UniformData{
-      numel_,
-      sizes_,
-      dim_order_,
-      strides_,
-      calculate_logical_limits(storage_->image_extents_, axis_map_)});
+  // uniform_data_ is only valid for low dimensional tensors
+  if (sizes.size() <= 4) {
+    uniform_data_ = std::make_shared<UniformData>(UniformData{
+        numel_,
+        sizes_,
+        dim_order_,
+        strides_,
+        calculate_logical_limits(storage_->image_extents_, axis_map_)});
+  }
 
   VK_CHECK_COND(
       dim_order_is_valid(dim_order_), "computed dim order is invalid");
@@ -813,24 +820,29 @@ size_t vTensor::get_max_ubo_nbytes(const size_t nbytes_per_ubo) const {
 }
 
 const vkapi::BufferBindInfo vTensor::sizes_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(&sizes_uniform_offset_, uniform_data_->sizes_v);
 }
 
 const vkapi::BufferBindInfo vTensor::dim_order_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(
       &dim_order_uniform_offset_, uniform_data_->dim_order_v);
 }
 
 const vkapi::BufferBindInfo vTensor::strides_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(&strides_uniform_offset, uniform_data_->strides_v);
 }
 
 const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(
       &logical_limits_uniform_offset_, uniform_data_->logical_limits);
 }
 
 const vkapi::BufferBindInfo vTensor::numel_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(&numel_uniform_offset_, uniform_data_->numel);
 }
 
@@ -893,31 +905,33 @@ void vTensor::update_metadata() {
   strides_ = calculate_strides(sizes_, dim_order_);
 
   // Update uniform data if it has been modified
-  uniform_data_->numel = utils::safe_downcast<int32_t>(numel_);
-  uniform_data_->sizes_v =
-      flip_and_unsqueeze_ivec4(sizes_, kTensorSizes, numel_);
-  uniform_data_->dim_order_v =
-      flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_);
-  uniform_data_->strides_v =
-      flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_);
-  uniform_data_->logical_limits.limits =
-      calculate_logical_limits(sizes_, axis_map_, packed_dim_);
-
-  if (sizes_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
-  }
-  if (dim_order_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->dim_order_v, dim_order_uniform_offset_);
-  }
-  if (strides_uniform_offset != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->strides_v, strides_uniform_offset);
-  }
-  if (numel_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(numel_, numel_uniform_offset_);
-  }
-  if (logical_limits_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(
-        uniform_data_->logical_limits.limits, logical_limits_uniform_offset_);
+  if (sizes_.size() <= 4) {
+    uniform_data_->numel = utils::safe_downcast<int32_t>(numel_);
+    uniform_data_->sizes_v =
+        flip_and_unsqueeze_ivec4(sizes_, kTensorSizes, numel_);
+    uniform_data_->dim_order_v =
+        flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_);
+    uniform_data_->strides_v =
+        flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_);
+    uniform_data_->logical_limits.limits =
+        calculate_logical_limits(sizes_, axis_map_, packed_dim_);
+
+    if (sizes_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
+    }
+    if (dim_order_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(uniform_data_->dim_order_v, dim_order_uniform_offset_);
+    }
+    if (strides_uniform_offset != kUniformOffsetUnset) {
+      uniforms_.update(uniform_data_->strides_v, strides_uniform_offset);
+    }
+    if (numel_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(numel_, numel_uniform_offset_);
+    }
+    if (logical_limits_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(
+          uniform_data_->logical_limits.limits, logical_limits_uniform_offset_);
+    }
   }
 
   if (buffer_meta_.buffer()) {
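Taken together, these hunks gate image-extent computation and UBO metadata behind a check that the tensor has at most 4 dims. Below is a minimal standalone sketch of that guard, under simplifying assumptions: UVec3 and image_extents_or_unit are hypothetical stand-ins for utils::uvec3 and calculate_image_extents, and the low-dim mapping deliberately ignores axis_map and packed-dim padding.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for utils::uvec3; not the actual ExecuTorch type.
struct UVec3 {
  uint32_t x = 1, y = 1, z = 1;
};

// Mirrors the guard added to calculate_image_extents: tensors with more than
// 4 padded dims must use buffer storage, so their image extents stay {1,1,1}.
UVec3 image_extents_or_unit(const std::vector<int64_t>& padded_sizes) {
  UVec3 extents;
  if (padded_sizes.size() > 4) {
    return extents;
  }
  // Toy mapping for <= 4 dims: copy the innermost three dims onto X/Y/Z.
  // The real implementation maps W/H/C through axis_map and accounts for the
  // packed dim; those details are omitted here.
  const size_t n = padded_sizes.size();
  if (n > 0) extents.x = static_cast<uint32_t>(padded_sizes[n - 1]);
  if (n > 1) extents.y = static_cast<uint32_t>(padded_sizes[n - 2]);
  if (n > 2) extents.z = static_cast<uint32_t>(padded_sizes[n - 3]);
  return extents;
}

int main() {
  const UVec3 lo = image_extents_or_unit({2, 3, 8, 8});    // 4 dims: mapped extents
  const UVec3 hi = image_extents_or_unit({2, 2, 3, 8, 8}); // 5 dims: unit extents
  std::cout << lo.x << "x" << lo.y << "x" << lo.z << "\n"; // prints 8x8x3
  std::cout << hi.x << "x" << hi.y << "x" << hi.z << "\n"; // prints 1x1x1
  return 0;
}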