Skip to content

Commit e34d487

Browse files
author
ssjia
committed
Update base for Update on "[ET-VK][ez] Allow high dimensional tensors (for buffer storage)"
Differential Revision: [D80800083](https://our.internmc.facebook.com/intern/diff/D80800083) [ghstack-poisoned]
1 parent dcd9fac commit e34d487

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -715,17 +715,17 @@ uint32_t vTensor::UniformData::write_attribute(
715715
}
716716

717717
vTensor::BufferMetadata::BufferMetadata(
718-
std::vector<int64_t> src_sizes,
719-
std::vector<int64_t> src_dim_order,
720-
std::vector<int64_t> src_strides,
718+
std::vector<int64_t>& src_sizes,
719+
std::vector<int64_t>& src_dim_order,
720+
std::vector<int64_t>& src_strides,
721721
size_t src_numel) {
722722
update(src_sizes, src_dim_order, src_strides, src_numel);
723723
}
724724

725725
void vTensor::BufferMetadata::update(
726-
std::vector<int64_t> src_sizes,
727-
std::vector<int64_t> src_dim_order,
728-
std::vector<int64_t> src_strides,
726+
std::vector<int64_t>& src_sizes,
727+
std::vector<int64_t>& src_dim_order,
728+
std::vector<int64_t>& src_strides,
729729
size_t src_numel) {
730730
int32_t fixed_ndim = utils::safe_downcast<int32_t>(kTensorDimLimit);
731731

backends/vulkan/runtime/api/containers/Tensor.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ class vTensor final {
248248
UniformData(
249249
const size_t numel_ll,
250250
const std::vector<int64_t>& sizes,
251-
const std::vector<int64_t>& whcn_dim_order,
251+
const std::vector<int64_t>& dim_order,
252252
const std::vector<int64_t>& strides,
253253
const utils::uvec3& limits);
254254

@@ -272,15 +272,15 @@ class vTensor final {
272272
uint32_t numel;
273273

274274
BufferMetadata(
275-
std::vector<int64_t> sizes,
276-
std::vector<int64_t> dim_order,
277-
std::vector<int64_t> strides,
275+
std::vector<int64_t>& sizes,
276+
std::vector<int64_t>& dim_order,
277+
std::vector<int64_t>& strides,
278278
size_t numel);
279279

280280
void update(
281-
std::vector<int64_t> sizes,
282-
std::vector<int64_t> dim_order,
283-
std::vector<int64_t> strides,
281+
std::vector<int64_t>& sizes,
282+
std::vector<int64_t>& dim_order,
283+
std::vector<int64_t>& strides,
284284
size_t numel);
285285
};
286286

@@ -355,7 +355,7 @@ class vTensor final {
355355
ParamsBuffer uniforms_;
356356

357357
/*
358-
* TODO: explain
358+
* Used to store data for BufferMetadata to pass to shaders as buffer_meta_ubo
359359
*/
360360
ParamsBuffer buffer_meta_;
361361

0 commit comments

Comments
 (0)