|
7 | 7 | */ |
8 | 8 |
|
9 | 9 | #include <executorch/backends/vulkan/runtime/api/containers/Tensor.h> |
| 10 | +#include <cassert> |
10 | 11 | #include <cstring> |
11 | 12 |
|
12 | 13 | namespace vkcompute { |
@@ -99,12 +100,31 @@ std::vector<int64_t> calculate_strides( |
99 | 100 | * |
100 | 101 | * The axis mapping allows for permuted views of texture-backed tensors. |
101 | 102 | */ |
102 | | -std::vector<int64_t> default_axis_map() { |
103 | | - // Currently, all compute shaders have an assumption that the channels dim is |
104 | | - // used to combine with the batch dim of a tensor. However, once dim mapping |
105 | | - // is integrated into the tensor indexing logic for each compute shader, we |
106 | | - // can be more flexible with mapping the batch dim to different texture axes |
107 | | - // in order to improve performance or memory footprint. |
| 103 | +std::vector<int64_t> calculate_axis_map( |
| 104 | + const std::vector<int64_t>& sizes, |
| 105 | + utils::AxisMapLayout axis_map_layout) { |
| 106 | + if (axis_map_layout == utils::AxisMapLayout::OPTIMIZED) { |
| 107 | + std::vector<int64_t> axis_map(sizes.size() + 1); |
| 108 | + std::iota(axis_map.begin(), axis_map.end() - 1, 0); |
| 109 | + |
| 110 | + std::stable_sort( |
| 111 | + axis_map.begin(), axis_map.end() - 1, [&sizes](size_t i1, size_t i2) { |
| 112 | + return sizes[i1] < sizes[i2]; |
| 113 | + }); |
| 114 | + |
| 115 | + assert(axis_map.size() > 0); |
| 116 | + // Find the index of the channel dimension |
| 117 | + for (size_t i = 0; i < axis_map.size() - 1; ++i) { |
| 118 | + assert(sizes.size() > axis_map[i]); |
| 119 | + if (sizes[axis_map[i]] == 2) { |
| 120 | + axis_map.back() = i; |
| 121 | + break; |
| 122 | + } |
| 123 | + } |
| 124 | + |
| 125 | + return axis_map; |
| 126 | + } |
| 127 | + // default |
108 | 128 | return {0, 1, 2, 2}; |
109 | 129 | } |
110 | 130 |
|
@@ -439,13 +459,14 @@ vTensor::vTensor( |
439 | 459 | const vkapi::ScalarType dtype, |
440 | 460 | const utils::StorageType storage_type, |
441 | 461 | const utils::GPUMemoryLayout memory_layout, |
442 | | - const bool allocate_memory) |
| 462 | + const bool allocate_memory, |
| 463 | + const utils::AxisMapLayout axis_map_layout) |
443 | 464 | : dtype_(dtype), |
444 | 465 | // Calculate tensor metadata |
445 | 466 | sizes_(sizes.begin(), sizes.end()), |
446 | 467 | packed_dim_(utils::to_packed_dim<int32_t>(memory_layout)), |
447 | 468 | dim_order_(calculate_dim_order(sizes_.size(), packed_dim_)), |
448 | | - axis_map_(default_axis_map()), |
| 469 | + axis_map_(calculate_axis_map(sizes_, axis_map_layout)), |
449 | 470 | strides_(calculate_strides(sizes, dim_order_)), |
450 | 471 | padded_sizes_{calculate_padded_sizes(sizes, packed_dim_)}, |
451 | 472 | unsqueezed_strides_{ |
@@ -484,13 +505,14 @@ vTensor::vTensor( |
484 | 505 | vTensor::vTensor( |
485 | 506 | Context* context, |
486 | 507 | const vkapi::VulkanImage& image, |
487 | | - const utils::GPUMemoryLayout memory_layout) |
| 508 | + const utils::GPUMemoryLayout memory_layout, |
| 509 | + const utils::AxisMapLayout axis_map_layout) |
488 | 510 | : dtype_(vkapi::element_scalartype(image.format())), |
489 | 511 | // Calculate tensor metadata |
490 | 512 | sizes_(calculate_sizes(image, memory_layout)), |
491 | 513 | packed_dim_(utils::to_packed_dim<int32_t>(memory_layout)), |
492 | 514 | dim_order_(), |
493 | | - axis_map_(default_axis_map()), |
| 515 | + axis_map_(calculate_axis_map(sizes_, axis_map_layout)), |
494 | 516 | strides_(), |
495 | 517 | padded_sizes_(calculate_padded_sizes(sizes_, packed_dim_)), |
496 | 518 | unsqueezed_strides_(), |
@@ -547,7 +569,7 @@ vTensor::vTensor( |
547 | 569 | sizes_(sizes.begin(), sizes.end()), |
548 | 570 | packed_dim_(other.packed_dim_), |
549 | 571 | dim_order_(dim_order.begin(), dim_order.end()), |
550 | | - axis_map_(default_axis_map()), |
| 572 | + axis_map_(calculate_axis_map(sizes_, utils::kDefaultAxisMap)), |
551 | 573 | strides_(calculate_strides(sizes_, dim_order_)), |
552 | 574 | padded_sizes_{calculate_padded_sizes(sizes, packed_dim_)}, |
553 | 575 | unsqueezed_strides_{ |
|
0 commit comments