@@ -166,18 +166,25 @@ void add_matmul_optimized_node(
       /* passthrough = */ true);
 
   // Ensure mat1 is width packed
-  ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
+  TmpTensor mat1_tmp(
+      &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
+  ValueRef mat1_W_packed = mat1;
   auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
-  viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
+  if (graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
+    mat1_W_packed = mat1_tmp;
+    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
+  }
 
   const bool mat2_is_transposed_val = graph.get_bool(mat2_is_transposed);
 
   // Ensure mat2 to height packed
   ValueRef mat2_packed = mat2;
   const utils::GPUMemoryLayout mat2_layout =
       mat2_is_transposed_val ? utils::kWidthPacked : utils::kHeightPacked;
+  TmpTensor mat2_tmp(
+      &graph, graph.sizes_of(mat2), graph.dtype_of(mat2), mat2_layout);
   if (graph.estimate_memory_layout_of(mat2) != mat2_layout) {
-    mat2_packed = graph.add_tensor_like(mat2, mat2_layout);
+    mat2_packed = mat2_tmp;
     viewFn(graph, {mat2, graph.add_none(), mat2_packed});
   }
 
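A note on the pattern above (not part of the commit itself): the change declares a scoped temporary up front and only routes an input through the repacking view op when its packed dimension differs from what the shader expects, so an already width-packed mat1 (or correctly packed mat2) avoids an extra copy. The sketch below is a minimal, framework-free illustration of that conditional-repack idea under assumed types; Tensor, Layout, repack_into, and ensure_layout are hypothetical stand-ins, not the ComputeGraph/TmpTensor API.

#include <cassert>
#include <cstdint>
#include <vector>

enum class Layout { kWidthPacked, kHeightPacked };

struct Tensor {
  std::vector<int64_t> sizes;
  Layout layout;
};

// Hypothetical stand-in for the aten.view_copy.default dispatch: copies the
// source contents into a destination created with the desired layout.
void repack_into(const Tensor& src, Tensor& dst) {
  dst.sizes = src.sizes;
}

// Returns a tensor guaranteed to have the wanted layout. The caller owns
// `tmp` so its storage outlives the returned reference, mirroring how
// mat1_tmp / mat2_tmp are declared before the `if` in the diff above.
const Tensor& ensure_layout(const Tensor& src, Layout wanted, Tensor& tmp) {
  if (src.layout == wanted) {
    return src;  // Already packed as needed; skip the copy entirely.
  }
  tmp.layout = wanted;
  repack_into(src, tmp);
  return tmp;
}

int main() {
  Tensor mat1{{4, 8}, Layout::kHeightPacked};
  Tensor scratch{{}, Layout::kWidthPacked};  // Plays the role of mat1_tmp.
  const Tensor& mat1_w = ensure_layout(mat1, Layout::kWidthPacked, scratch);
  assert(mat1_w.layout == Layout::kWidthPacked);
  return 0;
}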