From 3b95e14a286ccd9bd04aa1a8f58607072eae81e2 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Thu, 23 Jan 2025 22:17:59 -0800 Subject: [PATCH] [ET-VK] Using TmpTensor for width packed versions of q_linear op shader to reduce memory usage. This diff introduces the use of temporary tensors to reduce memory usage in the width packed versions of the q_linear op shader. Differential Revision: [D68561647](https://our.internmc.facebook.com/intern/diff/D68561647/) [ghstack-poisoned] --- backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index a78ac0519c4..66520b631fd 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -73,13 +73,16 @@ void add_q_8w_linear_node( auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); ValueRef mat1_W_packed = mat1; ValueRef out_W_packed = out; + // Create temporary tensors to store the width packed versions of mat1 and out + TmpTensor mat1_tmp(&graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked); + TmpTensor out_tmp(&graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked); if (!graph.is_buffer_storage(out) && graph.packed_dim_of(mat1) != WHCN::kWidthDim) { // Ensure mat1 is width packed - mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked); + mat1_W_packed = mat1_tmp; viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); // Ensure out is packed correctly - out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked); + out_W_packed = out_tmp; } ValueRef q_mat2 = prepack_standard( graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked);