Skip to content

Commit a88e880

Browse files
authored
[ET-VK] Using TmpTensor for width packed versions of q_linear op shader to reduce memory usage.
Differential Revision: D68561647 Pull Request resolved: #7929
1 parent 25fd07d commit a88e880

File tree

1 file changed: +7 additions, -2 deletions

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,18 @@ void add_q_8w_linear_node(
   auto viewFn = VK_GET_OP_FN("aten.view_copy.default");

   ValueRef mat1_W_packed = mat1;
   ValueRef out_W_packed = out;
+  // Create temporary tensors to store the width packed versions of mat1 and out
+  TmpTensor mat1_tmp(
+      &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
+  TmpTensor out_tmp(
+      &graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked);
   if (!graph.is_buffer_storage(out) &&
       graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
     // Ensure mat1 is width packed
-    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
+    mat1_W_packed = mat1_tmp;
     viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
     // Ensure out is packed correctly
-    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
+    out_W_packed = out_tmp;
   }
   ValueRef q_mat2 = prepack_standard(
       graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked);

0 commit comments

Comments (0)