Skip to content

Commit 8fc5f59

Browse files
author
ssjia
committed
Update on "[ET-VK][ez] Allow high dimensional tensors (for buffer storage)"
Differential Revision: [D80800083](https://our.internmc.facebook.com/intern/diff/D80800083) [ghstack-poisoned]
2 parents a465dca + e34d487 commit 8fc5f59

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -722,17 +722,17 @@ uint32_t vTensor::UniformData::write_attribute(
722722
}
723723

724724
vTensor::BufferMetadata::BufferMetadata(
725-
std::vector<int64_t> src_sizes,
726-
std::vector<int64_t> src_dim_order,
727-
std::vector<int64_t> src_strides,
725+
std::vector<int64_t>& src_sizes,
726+
std::vector<int64_t>& src_dim_order,
727+
std::vector<int64_t>& src_strides,
728728
size_t src_numel) {
729729
update(src_sizes, src_dim_order, src_strides, src_numel);
730730
}
731731

732732
void vTensor::BufferMetadata::update(
733-
std::vector<int64_t> src_sizes,
734-
std::vector<int64_t> src_dim_order,
735-
std::vector<int64_t> src_strides,
733+
std::vector<int64_t>& src_sizes,
734+
std::vector<int64_t>& src_dim_order,
735+
std::vector<int64_t>& src_strides,
736736
size_t src_numel) {
737737
int32_t fixed_ndim = utils::safe_downcast<int32_t>(kTensorDimLimit);
738738

backends/vulkan/runtime/api/containers/Tensor.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ class vTensor final {
248248
UniformData(
249249
const size_t numel_ll,
250250
const std::vector<int64_t>& sizes,
251-
const std::vector<int64_t>& whcn_dim_order,
251+
const std::vector<int64_t>& dim_order,
252252
const std::vector<int64_t>& strides,
253253
const utils::uvec3& limits);
254254

@@ -272,15 +272,15 @@ class vTensor final {
272272
uint32_t numel;
273273

274274
BufferMetadata(
275-
std::vector<int64_t> sizes,
276-
std::vector<int64_t> dim_order,
277-
std::vector<int64_t> strides,
275+
std::vector<int64_t>& sizes,
276+
std::vector<int64_t>& dim_order,
277+
std::vector<int64_t>& strides,
278278
size_t numel);
279279

280280
void update(
281-
std::vector<int64_t> sizes,
282-
std::vector<int64_t> dim_order,
283-
std::vector<int64_t> strides,
281+
std::vector<int64_t>& sizes,
282+
std::vector<int64_t>& dim_order,
283+
std::vector<int64_t>& strides,
284284
size_t numel);
285285
};
286286

@@ -355,7 +355,7 @@ class vTensor final {
355355
ParamsBuffer uniforms_;
356356

357357
/*
358-
* TODO: explain
358+
* Used to store data for BufferMetadata to pass to shaders as buffer_meta_ubo
359359
*/
360360
ParamsBuffer buffer_meta_;
361361

0 commit comments

Comments
 (0)