Skip to content

Commit cddb993

Browse files
trivedivivekfacebook-github-bot
authored andcommitted
Use TmpTensor for MatMul op.
Summary: This diff introduces the use of temporary tensors to reduce memory usage in the width packed versions of the matmul op shader. Differential Revision: D68924743
1 parent 3413971 commit cddb993

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ void add_matmul_optimized_node(
166166
/*passthrough = */ true);
167167

168168
// Ensure mat1 is width packed
169-
ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
169+
TmpTensor mat1_tmp(&graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
170+
ValueRef mat1_W_packed = mat1_tmp;
170171
auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
171172
viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
172173

@@ -176,8 +177,9 @@ void add_matmul_optimized_node(
176177
ValueRef mat2_packed = mat2;
177178
const utils::GPUMemoryLayout mat2_layout =
178179
mat2_is_transposed_val ? utils::kWidthPacked : utils::kHeightPacked;
180+
TmpTensor mat2_tmp(&graph, graph.sizes_of(mat2), graph.dtype_of(mat2), mat2_layout);
179181
if (graph.estimate_memory_layout_of(mat2) != mat2_layout) {
180-
mat2_packed = graph.add_tensor_like(mat2, mat2_layout);
182+
mat2_packed = mat2_tmp;
181183
viewFn(graph, {mat2, graph.add_none(), mat2_packed});
182184
}
183185

0 commit comments

Comments
 (0)