Commit 865dc96

[ET-VK] Using TmpTensor for width packed versions of q_linear op shader to reduce memory usage.
This diff introduces the use of temporary tensors to reduce memory usage in the width packed versions of the q_linear op shader.

Differential Revision: [D68561647](https://our.internmc.facebook.com/intern/diff/D68561647/)
ghstack-source-id: 262822997
Pull Request resolved: #7929
1 parent c73d46d commit 865dc96

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 5 additions & 2 deletions
@@ -73,13 +73,16 @@ void add_q_8w_linear_node(
   auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
   ValueRef mat1_W_packed = mat1;
   ValueRef out_W_packed = out;
+  // Create temporary tensors to store the width packed versions of mat1 and out
+  TmpTensor mat1_tmp(&graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
+  TmpTensor out_tmp(&graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked);
   if (!graph.is_buffer_storage(out) &&
       graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
     // Ensure mat1 is width packed
-    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
+    mat1_W_packed = mat1_tmp;
     viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
     // Ensure out is packed correctly
-    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
+    out_W_packed = out_tmp;
   }
   ValueRef q_mat2 = prepack_standard(
       graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked);
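For reference, below is a minimal usage sketch (not part of this commit) of the pattern that motivates the change. It assumes, based on the TmpTensor API used in the diff above, that a temporary's backing memory can be reclaimed and reused once the TmpTensor goes out of scope, whereas a tensor created with graph.add_tensor_like() remains allocated in the graph for its whole lifetime. The helper function name, includes, and namespace below are illustrative assumptions rather than code from the repository.

// Sketch only: a TmpTensor scoped to a single op-builder call keeps the
// width packed staging copy from permanently growing the graph's memory.
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

using namespace vkcompute;

// Hypothetical helper: adds an op that needs a width packed copy of `in`.
void add_node_with_width_packed_copy(ComputeGraph& graph, ValueRef in) {
  // Same constructor as in the diff: sizes and dtype copied from `in`,
  // memory layout forced to width packed.
  TmpTensor in_W_packed(
      &graph, graph.sizes_of(in), graph.dtype_of(in), utils::kWidthPacked);

  // Repack `in` into the temporary, as the diff does for mat1.
  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
  viewFn(graph, {in, graph.add_none(), in_W_packed});

  // ... add compute nodes that read in_W_packed here ...
}  // in_W_packed goes out of scope; the assumption is that its memory is
   // returned to the graph's pool so later temporaries can reuse it.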
