diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index a78ac0519c4..1042c23bcb3 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -73,13 +73,18 @@ void add_q_8w_linear_node( auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); ValueRef mat1_W_packed = mat1; ValueRef out_W_packed = out; + // Create temporary tensors to store the width packed versions of mat1 and out + TmpTensor mat1_tmp( + &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked); + TmpTensor out_tmp( + &graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked); if (!graph.is_buffer_storage(out) && graph.packed_dim_of(mat1) != WHCN::kWidthDim) { // Ensure mat1 is width packed - mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked); + mat1_W_packed = mat1_tmp; viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); // Ensure out is packed correctly - out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked); + out_W_packed = out_tmp; } ValueRef q_mat2 = prepack_standard( graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked);