@@ -189,10 +189,14 @@ utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
     const std::vector<int64_t>& axis_map,
     const int32_t packed_dim) {
-  VK_CHECK_COND(padded_sizes.size() == 4);
-  VK_CHECK_COND(axis_map.size() == 4);
-
   utils::uvec3 extents({1, 1, 1});
+
+  // For high dimensional tensors, buffer storage must be used. No need to
+  // compute image extents in this case.
+  if (padded_sizes.size() > 4) {
+    return extents;
+  }
+
   // First three elements of axis_map indicate which (X,Y,Z) image axis the
   // width, height, and channels dim of the tensor maps to.
   for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
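For orientation (not part of the patch): after this change, calculate_image_extents simply reports a degenerate {1, 1, 1} extent for tensors with more than four dims, on the assumption that such tensors are always backed by buffer storage rather than a 3D image. A minimal sketch of that assumption, using a hypothetical helper name:

// Illustrative only; choose_storage_type() is a hypothetical helper, not an
// ExecuTorch API. It captures the invariant the early return relies on.
#include <cstdint>
#include <vector>

enum class StorageType { kTexture3D, kBuffer };

StorageType choose_storage_type(const std::vector<int64_t>& padded_sizes) {
  // Tensors with more than 4 dims cannot be mapped onto an (X, Y, Z) image
  // extent, so they must fall back to buffer storage.
  return padded_sizes.size() > 4 ? StorageType::kBuffer
                                 : StorageType::kTexture3D;
}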
@@ -577,12 +581,15 @@ vTensor::vTensor(
         sizes,
         dtype_,
         allocate_memory)) {
-  uniform_data_ = std::make_shared<UniformData>(UniformData{
-      numel_,
-      sizes_,
-      dim_order_,
-      strides_,
-      calculate_logical_limits(storage_->image_extents_, axis_map_)});
+  // uniform_data_ is only valid for low dimensional (rank <= 4) tensors
+  if (sizes.size() <= 4) {
+    uniform_data_ = std::make_shared<UniformData>(UniformData{
+        numel_,
+        sizes_,
+        dim_order_,
+        strides_,
+        calculate_logical_limits(storage_->image_extents_, axis_map_)});
+  }
 
   VK_CHECK_COND(
       dim_order_is_valid(dim_order_), "computed dim order is invalid");
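The reason for the rank <= 4 restriction is that UniformData packs sizes, dim order, and strides each into a single ivec4 (via flip_and_unsqueeze_ivec4), so there is no room for a fifth dimension. A rough, hypothetical stand-in for that packing, shown only to make the limit concrete (the real helper also handles per-field padding rules, which are omitted here):

// Hypothetical sketch, not the real flip_and_unsqueeze_ivec4: dims are
// reversed into WHCN order and the remainder is padded so the result always
// fits a single ivec4. Anything beyond rank 4 cannot be represented this way.
#include <array>
#include <cstdint>
#include <vector>

std::array<int32_t, 4> pack_whcn_ivec4(const std::vector<int64_t>& vals) {
  std::array<int32_t, 4> out{1, 1, 1, 1};
  for (size_t i = 0; i < vals.size() && i < 4; ++i) {
    out[i] = static_cast<int32_t>(vals[vals.size() - 1 - i]);
  }
  return out;
}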
@@ -814,24 +821,29 @@ size_t vTensor::get_max_ubo_nbytes(const size_t nbytes_per_ubo) const {
 }
 
 const vkapi::BufferBindInfo vTensor::sizes_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(&sizes_uniform_offset_, uniform_data_->sizes_v);
 }
 
 const vkapi::BufferBindInfo vTensor::dim_order_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(
       &dim_order_uniform_offset_, uniform_data_->dim_order_v);
 }
 
 const vkapi::BufferBindInfo vTensor::strides_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(&strides_uniform_offset, uniform_data_->strides_v);
 }
 
 const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(
       &logical_limits_uniform_offset_, uniform_data_->logical_limits);
 }
 
 const vkapi::BufferBindInfo vTensor::numel_ubo() {
+  VK_CHECK_COND(sizes_.size() <= 4);
   return metadata_ubo_impl(&numel_uniform_offset_, uniform_data_->numel);
 }
 
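Since uniform_data_ is only created for rank <= 4 tensors (see the constructor change above), each of these accessors would otherwise dereference a null shared_ptr when called on a higher-rank tensor; the new VK_CHECK_COND turns that into an explicit failure. A simplified stand-in illustrating the failure mode being guarded against:

// Simplified stand-in types; not the real vTensor. The point is that the
// rank check must come before touching uniform_data, which is null for
// tensors with more than 4 dims.
#include <array>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>

struct UniformData {
  std::array<int32_t, 4> sizes_v{1, 1, 1, 1};
};

struct FakeTensor {
  std::vector<int64_t> sizes;
  std::shared_ptr<UniformData> uniform_data;  // only allocated when rank <= 4

  const std::array<int32_t, 4>& sizes_v() const {
    if (sizes.size() > 4) {
      // Mirrors VK_CHECK_COND: fail loudly instead of dereferencing null.
      throw std::runtime_error("ivec4 metadata undefined for rank > 4 tensors");
    }
    return uniform_data->sizes_v;
  }
};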
@@ -894,31 +906,33 @@ void vTensor::update_metadata() {
   strides_ = calculate_strides(sizes_, dim_order_);
 
   // Update uniform data if it has been modified
-  uniform_data_->numel = utils::safe_downcast<int32_t>(numel_);
-  uniform_data_->sizes_v =
-      flip_and_unsqueeze_ivec4(sizes_, kTensorSizes, numel_);
-  uniform_data_->dim_order_v =
-      flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_);
-  uniform_data_->strides_v =
-      flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_);
-  uniform_data_->logical_limits.limits =
-      calculate_logical_limits(sizes_, axis_map_, packed_dim_);
-
-  if (sizes_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
-  }
-  if (dim_order_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->dim_order_v, dim_order_uniform_offset_);
-  }
-  if (strides_uniform_offset != kUniformOffsetUnset) {
-    uniforms_.update(uniform_data_->strides_v, strides_uniform_offset);
-  }
-  if (numel_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(numel_, numel_uniform_offset_);
-  }
-  if (logical_limits_uniform_offset_ != kUniformOffsetUnset) {
-    uniforms_.update(
-        uniform_data_->logical_limits.limits, logical_limits_uniform_offset_);
+  if (sizes_.size() <= 4) {
+    uniform_data_->numel = utils::safe_downcast<int32_t>(numel_);
+    uniform_data_->sizes_v =
+        flip_and_unsqueeze_ivec4(sizes_, kTensorSizes, numel_);
+    uniform_data_->dim_order_v =
+        flip_and_unsqueeze_ivec4(dim_order_, kTensorDimOrder, numel_);
+    uniform_data_->strides_v =
+        flip_and_unsqueeze_ivec4(strides_, kTensorStrides, numel_);
+    uniform_data_->logical_limits.limits =
+        calculate_logical_limits(sizes_, axis_map_, packed_dim_);
+
+    if (sizes_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(uniform_data_->sizes_v, sizes_uniform_offset_);
+    }
+    if (dim_order_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(uniform_data_->dim_order_v, dim_order_uniform_offset_);
+    }
+    if (strides_uniform_offset != kUniformOffsetUnset) {
+      uniforms_.update(uniform_data_->strides_v, strides_uniform_offset);
+    }
+    if (numel_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(numel_, numel_uniform_offset_);
+    }
+    if (logical_limits_uniform_offset_ != kUniformOffsetUnset) {
+      uniforms_.update(
+          uniform_data_->logical_limits.limits, logical_limits_uniform_offset_);
+    }
   }
 
   if (buffer_meta_.buffer()) {
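The guarded block above also illustrates the lazy uniform-offset pattern: each *_uniform_offset_ member stays at kUniformOffsetUnset until the corresponding *_ubo() accessor is first called, and update_metadata() only rewrites slots that have actually been handed out. A small self-contained sketch of that pattern, with simplified stand-in names:

// Sketch of the "update only if the offset was ever requested" pattern;
// ParamBuffer and kOffsetUnset are simplified stand-ins, not ExecuTorch types.
#include <cstdint>
#include <cstring>
#include <vector>

constexpr uint32_t kOffsetUnset = 0xFFFFFFFFu;

struct ParamBuffer {
  std::vector<uint8_t> bytes;

  // Reserve space for a value the first time it is requested, returning its
  // offset so later updates can write to the same slot.
  template <typename T>
  uint32_t append(const T& value) {
    const uint32_t offset = static_cast<uint32_t>(bytes.size());
    bytes.resize(bytes.size() + sizeof(T));
    std::memcpy(bytes.data() + offset, &value, sizeof(T));
    return offset;
  }

  // Overwrite a previously reserved slot in place.
  template <typename T>
  void update(const T& value, uint32_t offset) {
    std::memcpy(bytes.data() + offset, &value, sizeof(T));
  }
};

void refresh(ParamBuffer& ubo, int32_t numel, uint32_t numel_offset) {
  // Skip slots that were never bound by a shader.
  if (numel_offset != kOffsetUnset) {
    ubo.update(numel, numel_offset);
  }
}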