Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions backends/vulkan/runtime/api/containers/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,8 +615,7 @@ vTensor::vTensor(
sizes_,
whcn_dim_order,
unsqueezed_strides,
TextureLimits(
calculate_logical_limits(storage_->image_extents_, axis_map_)),
calculate_logical_limits(storage_->image_extents_, axis_map_),
numel_});
VK_CHECK_COND(
dim_order_is_valid(dim_order_), "computed dim order is invalid");
Expand Down Expand Up @@ -648,10 +647,12 @@ vTensor::vTensor(
uniforms_(),
// Construct Tensor storage
storage_(std::make_shared<vTensorStorage>(context, image)) {
TextureLimits logical_limits(
calculate_logical_limits(storage_->image_extents_, axis_map_));
uniform_data_ = std::make_shared<UniformData>(
UniformData{sizes_, {0, 0, 0, 0}, {0, 0, 0, 0}, logical_limits, numel_});
uniform_data_ = std::make_shared<UniformData>(UniformData{
sizes_,
{0, 0, 0, 0},
{0, 0, 0, 0},
calculate_logical_limits(storage_->image_extents_, axis_map_),
numel_});
}

vTensor::vTensor(vTensor& other)
Expand Down Expand Up @@ -698,7 +699,7 @@ vTensor::vTensor(
sizes_,
create_whcn_dim_order(dim_order_),
unsqueeze_strides(strides_, numel_),
{other.logical_limits()},
other.logical_limits(),
static_cast<size_t>(utils::multiply_integers(sizes_))});

VK_CHECK_COND(
Expand Down
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/api/containers/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ class vTensor final {
// component vector with components of size N must have base alignment of
// 4N.
alignas(16) utils::ivec3 limits;

TextureLimits(const utils::uvec3& ulimits) : limits{ulimits} {}
};

public:
Expand Down Expand Up @@ -249,7 +251,7 @@ class vTensor final {
const std::vector<int64_t>& sizes,
const std::vector<int64_t>& whcn_dim_order,
const std::vector<int64_t>& strides,
const TextureLimits& logical_limits,
const utils::uvec3& logical_limits,
const size_t numel_ll)
: sizes_v(utils::make_whcn_ivec4(sizes)),
whcn_dim_order_v(utils::make_ivec4(whcn_dim_order)),
Expand Down
Loading