@@ -89,11 +89,11 @@ std::vector<int64_t> calculate_strides(
  * tensor. Thus the axis mapping can be considered to be in WHCN dimension
  * order.
  *
- * The last value `axis_mapping.at(3)` indicates the WHCN index of the tensor
+ * The last value `axis_map.at(3)` indicates the WHCN index of the tensor
  * dimension along which batches will be concatenated. This dimension can be
  * referred to as the "inner dimension". To determine which image texture axis
  * is used for the concatenation, a double lookup will need to be performed
- * (axis_mapping.at(axis_mapping.at(3))).
+ * (axis_map.at(axis_map.at(3))).
  *
  * The reason for structuring the axis mapping this way is that for the batch
  * dim, two things need to be easily derived:
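A quick illustration of the double lookup described in the comment above (a minimal standalone sketch; the `{0, 1, 2, 2}` value matches `default_axis_map()` in the next hunk, everything else is illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Default mapping: W -> X, H -> Y, C -> Z, with batches concatenated along
  // the channels dim.
  const std::vector<int64_t> axis_map = {0, 1, 2, 2};
  // WHCN index of the "inner dimension" used for batch concatenation.
  const int64_t inner_whcn_dim = axis_map.at(3); // 2, i.e. channels
  // Double lookup: the image texture axis that dim maps to.
  const int64_t batch_axis = axis_map.at(inner_whcn_dim); // axis_map.at(2)
  assert(batch_axis == 2); // the Z axis of the texture
  return 0;
}
```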
@@ -107,7 +107,7 @@ std::vector<int64_t> calculate_strides(
  *
  * The axis mapping allows for permuted views of texture-backed tensors.
  */
-std::vector<int64_t> default_axis_mapping() {
+std::vector<int64_t> default_axis_map() {
   // Currently, all compute shaders have an assumption that the channels dim is
   // used to combine with the batch dim of a tensor. However, once dim mapping
   // is integrated into the tensor indexing logic for each compute shader, we
@@ -173,40 +173,40 @@ std::vector<int64_t> calculate_padded_sizes(
 
 utils::uvec3 calculate_image_extents(
     const std::vector<int64_t>& padded_sizes,
-    const std::vector<int64_t>& axis_mapping,
+    const std::vector<int64_t>& axis_map,
     const utils::GPUMemoryLayout memory_layout) {
   VK_CHECK_COND(padded_sizes.size() == 4);
-  VK_CHECK_COND(axis_mapping.size() == 4);
+  VK_CHECK_COND(axis_map.size() == 4);
 
   utils::uvec3 extents({1, 1, 1});
-  // First three elements of axis_mapping indicate which (X,Y,Z) image axis the
+  // First three elements of axis_map indicate which (X,Y,Z) image axis the
   // width, height, and channels dims of the tensor map to.
   for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
-    const int64_t axis = axis_mapping.at(whcn_dim);
+    const int64_t axis = axis_map.at(whcn_dim);
     const int64_t dim = padded_sizes.size() - 1 - whcn_dim;
     extents[axis] = utils::safe_downcast<uint32_t>(padded_sizes.at(dim));
   }
 
-  // axis_mapping[3] indicates the WHCN index of the dimension used for batch
+  // axis_map[3] indicates the WHCN index of the dimension used for batch
   // concatenation. Thus a double lookup is required to determine the image axis
   // used for batch concatenation.
-  const int64_t concatted_whcn_dim = axis_mapping.at(3);
-  const int64_t batch_axis = axis_mapping.at(concatted_whcn_dim);
+  const int64_t concatted_whcn_dim = axis_map.at(3);
+  const int64_t batch_axis = axis_map.at(concatted_whcn_dim);
   // Multiply the extents of the batch axis by the batch size.
   extents[batch_axis] *= padded_sizes.at(0);
 
   switch (memory_layout) {
     case utils::kWidthPacked:
-      VK_CHECK_COND(extents[0] % 4 == 0);
-      extents[0] /= 4;
+      VK_CHECK_COND(extents[axis_map.at(0)] % 4 == 0);
+      extents[axis_map.at(0)] /= 4;
       break;
     case utils::kHeightPacked:
-      VK_CHECK_COND(extents[1] % 4 == 0);
-      extents[1] /= 4;
+      VK_CHECK_COND(extents[axis_map.at(1)] % 4 == 0);
+      extents[axis_map.at(1)] /= 4;
       break;
     case utils::kChannelsPacked:
-      VK_CHECK_COND(extents[2] % 4 == 0);
-      extents[2] /= 4;
+      VK_CHECK_COND(extents[axis_map.at(2)] % 4 == 0);
+      extents[axis_map.at(2)] /= 4;
       break;
   }
 
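To make the extent calculation concrete, here is a self-contained sketch of the channels-packed path using plain std types in place of `utils::uvec3` (the inputs are assumed for illustration; this mirrors, but is not, the function above):

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

std::array<uint32_t, 3> extents_sketch(
    const std::vector<int64_t>& padded_sizes, // NCHW, 4 elements
    const std::vector<int64_t>& axis_map) {
  assert(padded_sizes.size() == 4 && axis_map.size() == 4);
  std::array<uint32_t, 3> extents = {1, 1, 1};
  // Map the tensor's W, H, C dims onto the texture's X, Y, Z axes.
  for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) {
    const int64_t axis = axis_map.at(whcn_dim);
    extents[axis] = static_cast<uint32_t>(padded_sizes.at(3 - whcn_dim));
  }
  // Double lookup for the batch axis, then fold the batch size into it.
  const int64_t batch_axis = axis_map.at(axis_map.at(3));
  extents[batch_axis] *= static_cast<uint32_t>(padded_sizes.at(0));
  // Channels-packed: four channel elements share one texel along that axis.
  assert(extents[axis_map.at(2)] % 4 == 0);
  extents[axis_map.at(2)] /= 4;
  return extents;
}

int main() {
  // NCHW {2, 8, 4, 6}; C = 8 is already a multiple of 4, so no extra padding.
  const auto e = extents_sketch({2, 8, 4, 6}, {0, 1, 2, 2});
  // W -> X = 6, H -> Y = 4, and Z = C * N / 4 = 8 * 2 / 4 = 4.
  assert(e[0] == 6 && e[1] == 4 && e[2] == 4);
  return 0;
}
```

Note why the change in this hunk matters: with a non-identity axis map, dividing `extents[2]` directly would shrink the wrong texture axis; indexing through `axis_map.at(2)` shrinks whichever axis the channels dim actually maps to.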
@@ -229,25 +229,27 @@ vTensor::vTensor(
       // Calculate tensor metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(calculate_dim_order(sizes_.size(), memory_layout_)),
-      axis_mapping_(default_axis_mapping()),
+      axis_map_(default_axis_map()),
       strides_(calculate_strides(sizes, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
       unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
       padded_numel_(utils::multiply_integers(padded_sizes_)),
       texture_limits_{{0, 0, 0}},
+      logical_limits_{{0, 0, 0}},
       // Utility Uniform Buffers that can be passed to shaders as arguments
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
-      axis_mapping_uniform_(),
+      axis_map_uniform_(),
       texture_limits_uniform_(),
+      logical_limits_uniform_(),
       // Construct Tensor storage
       storage_(
           context,
           storage_type,
           memory_layout_,
-          axis_mapping_,
+          axis_map_,
           padded_sizes_,
           dtype_,
           allocate_memory) {
@@ -259,6 +261,8 @@ vTensor::vTensor(
         utils::safe_downcast<int32_t>(storage_.image_extents_[0]),
         utils::safe_downcast<int32_t>(storage_.image_extents_[1]),
         utils::safe_downcast<int32_t>(storage_.image_extents_[2])};
+
+    update_logical_limits();
   }
 
   if (dtype == vkapi::kHalf) {
@@ -275,7 +279,7 @@ vTensor::vTensor(const vTensor& other)
       // Copy tensor size metadata
       sizes_(other.sizes_.begin(), other.sizes_.end()),
       dim_order_(other.dim_order_.begin(), other.dim_order_.end()),
-      axis_mapping_(other.axis_mapping_.begin(), other.axis_mapping_.end()),
+      axis_map_(other.axis_map_.begin(), other.axis_map_.end()),
       strides_(other.strides_.begin(), other.strides_.end()),
       numel_(other.numel_),
       padded_sizes_{other.padded_sizes_.begin(), other.padded_sizes_.end()},
@@ -284,12 +288,14 @@ vTensor::vTensor(const vTensor& other)
           other.unsqueezed_strides_.end()},
       padded_numel_(other.padded_numel_),
       texture_limits_{other.texture_limits_},
+      logical_limits_{other.logical_limits_},
       // Empty initialize Utility Uniform Buffers
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
-      axis_mapping_uniform_(),
+      axis_map_uniform_(),
       texture_limits_uniform_(),
+      logical_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_) {}
 
@@ -303,19 +309,21 @@ vTensor::vTensor(
       // Copy tensor size metadata
       sizes_(sizes.begin(), sizes.end()),
       dim_order_(dim_order.begin(), dim_order.end()),
-      axis_mapping_(default_axis_mapping()),
+      axis_map_(default_axis_map()),
       strides_(calculate_strides(sizes_, dim_order_)),
       numel_(utils::multiply_integers(sizes_)),
       padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)},
       unsqueezed_strides_{unsqueeze_strides(strides_, numel_)},
       padded_numel_(utils::multiply_integers(padded_sizes_)),
-      texture_limits_{{0, 0, 0}},
+      texture_limits_{other.texture_limits_},
+      logical_limits_(other.logical_limits_),
       // Empty initialize Utility Uniform Buffers
       sizes_uniform_(),
       strides_uniform_(),
       numel_uniform_(),
-      axis_mapping_uniform_(),
+      axis_map_uniform_(),
       texture_limits_uniform_(),
+      logical_limits_uniform_(),
       // Copy Tensor storage
       storage_(other.storage_, vkapi::element_size(dtype_) * offset_numel) {
   VK_CHECK_COND(
@@ -356,12 +364,18 @@ vkapi::VulkanBuffer& vTensor::buffer(
   return storage_.buffer_;
 }
 
-utils::uvec3 vTensor::mapped_extents() const {
-  utils::uvec3 m_extents;
-  m_extents[0] = storage_.image_extents_[axis_mapping_.at(0)];
-  m_extents[1] = storage_.image_extents_[axis_mapping_.at(1)];
-  m_extents[2] = storage_.image_extents_[axis_mapping_.at(2)];
-  return m_extents;
+void vTensor::update_logical_limits() {
+  logical_limits_.limits[0] = texture_limits_.limits[axis_map_.at(0)];
+  logical_limits_.limits[1] = texture_limits_.limits[axis_map_.at(1)];
+  logical_limits_.limits[2] = texture_limits_.limits[axis_map_.at(2)];
+}
+
+utils::uvec3 vTensor::logical_extents() const {
+  utils::uvec3 logical_extents(
+      {utils::safe_downcast<uint32_t>(logical_limits_.limits[0]),
+       utils::safe_downcast<uint32_t>(logical_limits_.limits[1]),
+       utils::safe_downcast<uint32_t>(logical_limits_.limits[2])});
+  return logical_extents;
 }
 
 const vkapi::BufferBindInfo vTensor::sizes_ubo() {
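A note on the new helper above: `update_logical_limits()` permutes the per-axis texture limits through the axis map so that limits are reported in the tensor's own WHCN order, which is what makes permuted views work. A standalone sketch with assumed values:

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Permute per-axis texture limits into WHCN ("logical") order via an axis map.
std::array<int32_t, 3> to_logical(
    const std::array<int32_t, 3>& texture_limits,
    const std::vector<int64_t>& axis_map) {
  return {
      texture_limits[axis_map.at(0)], // extent seen along the tensor's W dim
      texture_limits[axis_map.at(1)], // ... along H
      texture_limits[axis_map.at(2)]}; // ... along C
}

int main() {
  const std::array<int32_t, 3> limits = {6, 4, 16};
  // Default map: logical limits coincide with the texture limits.
  assert((to_logical(limits, {0, 1, 2, 2}) == std::array<int32_t, 3>{6, 4, 16}));
  // A width/height-transposed view: the X and Y limits swap places.
  assert((to_logical(limits, {1, 0, 2, 2}) == std::array<int32_t, 3>{4, 6, 16}));
  return 0;
}
```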
@@ -380,12 +394,12 @@ const vkapi::BufferBindInfo vTensor::strides_ubo() {
   return vkapi::BufferBindInfo(strides_uniform_.buffer());
 }
 
-const vkapi::BufferBindInfo vTensor::axis_mapping_ubo() {
-  if (!axis_mapping_uniform_.buffer()) {
-    axis_mapping_uniform_ =
-        ParamsBuffer(storage_.context_, utils::make_ivec4(axis_mapping_));
+const vkapi::BufferBindInfo vTensor::axis_map_ubo() {
+  if (!axis_map_uniform_.buffer()) {
+    axis_map_uniform_ =
+        ParamsBuffer(storage_.context_, utils::make_ivec4(axis_map_));
   }
-  return vkapi::BufferBindInfo(axis_mapping_uniform_.buffer());
+  return vkapi::BufferBindInfo(axis_map_uniform_.buffer());
 }
 
 const vkapi::BufferBindInfo vTensor::texture_limits_ubo() {
@@ -395,6 +409,13 @@ const vkapi::BufferBindInfo vTensor::texture_limits_ubo() {
   return vkapi::BufferBindInfo(texture_limits_uniform_.buffer());
 }
 
+const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
+  if (!logical_limits_uniform_.buffer()) {
+    logical_limits_uniform_ = ParamsBuffer(storage_.context_, logical_limits_);
+  }
+  return vkapi::BufferBindInfo(logical_limits_uniform_.buffer());
+}
+
 const vkapi::BufferBindInfo vTensor::numel_ubo() {
   if (!numel_uniform_.buffer()) {
     numel_uniform_ = ParamsBuffer(storage_.context_, numel_);
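`logical_limits_ubo()` above follows the same create-on-first-use pattern as the other metadata UBOs: the `ParamsBuffer` is allocated the first time a shader asks for the binding, and afterwards it is only refreshed (see `update_metadata()` further down). A toy sketch of the pattern with stand-in types (`FakeParamsBuffer` and `Metadata` are hypothetical; the real types are `ParamsBuffer` and `vTensor`):

```cpp
#include <cstdint>
#include <iostream>

// Stand-in for ParamsBuffer; purely illustrative.
struct FakeParamsBuffer {
  bool allocated = false;
  int32_t contents = 0;
};

struct Metadata {
  int32_t value = 0;
  FakeParamsBuffer uniform; // stays empty until first requested

  // Like logical_limits_ubo(): allocate lazily, then hand out the buffer.
  FakeParamsBuffer& ubo() {
    if (!uniform.allocated) {
      uniform.allocated = true;
      uniform.contents = value;
    }
    return uniform;
  }

  // Like update_metadata(): refresh the GPU-side copy only if it exists.
  void set(int32_t v) {
    value = v;
    if (uniform.allocated) {
      uniform.contents = value;
    }
  }
};

int main() {
  Metadata m;
  m.set(7);                              // no UBO yet; nothing to refresh
  std::cout << m.ubo().contents << "\n"; // 7, allocated on first use
  m.set(9);                              // the copy is now kept in sync
  std::cout << m.ubo().contents << "\n"; // 9
  return 0;
}
```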
@@ -465,14 +486,16 @@ void vTensor::update_metadata() {
   // Calculate the extents of the image texture that would have been required
   // for a tensor of the new sizes.
   utils::uvec3 virtual_extents =
-      calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
+      calculate_image_extents(padded_sizes_, axis_map_, memory_layout_);
 
   // Update the texture limits to reflect the new virtual extents.
   texture_limits_.limits = utils::ivec3{
       utils::safe_downcast<int32_t>(virtual_extents[0]),
       utils::safe_downcast<int32_t>(virtual_extents[1]),
       utils::safe_downcast<int32_t>(virtual_extents[2])};
 
+  update_logical_limits();
+
   if (sizes_uniform_.buffer()) {
     sizes_uniform_.update(utils::make_whcn_ivec4(sizes_));
   }
@@ -482,20 +505,23 @@ void vTensor::update_metadata() {
   if (numel_uniform_.buffer()) {
     numel_uniform_.update(numel_);
   }
-  if (axis_mapping_uniform_.buffer()) {
-    axis_mapping_uniform_.update(utils::make_ivec4(axis_mapping_));
+  if (axis_map_uniform_.buffer()) {
+    axis_map_uniform_.update(utils::make_ivec4(axis_map_));
   }
   if (texture_limits_uniform_.buffer()) {
     texture_limits_uniform_.update(texture_limits_);
   }
+  if (logical_limits_uniform_.buffer()) {
+    logical_limits_uniform_.update(logical_limits_);
+  }
 }
 
 void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
   if (storage_type() != utils::kBuffer) {
     // For texture storage, check that the current texture is large enough for
     // the new sizes of the tensor.
     utils::uvec3 virtual_extents =
-        calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_);
+        calculate_image_extents(padded_sizes_, axis_map_, memory_layout_);
 
     bool valid_resize = virtual_extents[0] <= image_extents()[0];
     valid_resize = valid_resize && virtual_extents[1] <= image_extents()[1];
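For context on the check above: texture-backed tensors support virtual resizing, so a resize is accepted only when the extents the new sizes would require fit within the extents that were actually allocated. A minimal sketch of that contract (values hypothetical):

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// A resize is valid only if the required extents fit the allocated extents.
bool valid_resize(
    const std::array<uint32_t, 3>& required,    // virtual_extents for new sizes
    const std::array<uint32_t, 3>& allocated) { // image_extents()
  return required[0] <= allocated[0] && required[1] <= allocated[1] &&
      required[2] <= allocated[2];
}

int main() {
  assert(valid_resize({6, 4, 4}, {8, 8, 4}));  // fits within the texture: ok
  assert(!valid_resize({6, 4, 8}, {8, 8, 4})); // needs more depth: rejected
  return 0;
}
```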
@@ -546,7 +572,7 @@ void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
   update_metadata();
   storage_.discard_and_reallocate(
       calculate_padded_sizes(new_sizes, memory_layout_),
-      axis_mapping_,
+      axis_map_,
       memory_layout_,
       dtype_);
 }
@@ -624,16 +650,14 @@ vTensorStorage::vTensorStorage(
     Context* const context,
     const utils::StorageType storage_type,
     const utils::GPUMemoryLayout gpu_memory_layout,
-    const std::vector<int64_t>& axis_mapping,
+    const std::vector<int64_t>& axis_map,
     const std::vector<int64_t>& padded_sizes,
     const vkapi::ScalarType dtype,
     const bool allocate_memory)
     : context_(context),
       storage_type_{storage_type},
-      image_extents_(calculate_image_extents(
-          padded_sizes,
-          axis_mapping,
-          gpu_memory_layout)),
+      image_extents_(
+          calculate_image_extents(padded_sizes, axis_map, gpu_memory_layout)),
       buffer_length_{utils::multiply_integers(padded_sizes)},
       buffer_offset_{0},
       image_(allocate_image(
@@ -746,7 +770,7 @@ bool vTensorStorage::is_copy_of(const vTensorStorage& other) const {
 
 void vTensorStorage::discard_and_reallocate(
     const std::vector<int64_t>& padded_sizes,
-    const std::vector<int64_t>& axis_mapping,
+    const std::vector<int64_t>& axis_map,
     const utils::GPUMemoryLayout gpu_memory_layout,
     const vkapi::ScalarType dtype) {
   const bool image_owns_memory = image_.owns_memory();
@@ -755,7 +779,7 @@ void vTensorStorage::discard_and_reallocate(
   flush();
 
   image_extents_ =
-      calculate_image_extents(padded_sizes, axis_mapping, gpu_memory_layout);
+      calculate_image_extents(padded_sizes, axis_map, gpu_memory_layout);
   image_ = allocate_image(
       context_,
       image_extents_,