Skip to content
Merged
9 changes: 7 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,18 @@ void add_q_8w_linear_node(
auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
ValueRef mat1_W_packed = mat1;
ValueRef out_W_packed = out;
// Create temporary tensors to store the width packed versions of mat1 and out
TmpTensor mat1_tmp(
&graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
TmpTensor out_tmp(
&graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked);
if (!graph.is_buffer_storage(out) &&
graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
// Ensure mat1 is width packed
mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
mat1_W_packed = mat1_tmp;
viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
// Ensure out is packed correctly
out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
out_W_packed = out_tmp;
}
ValueRef q_mat2 = prepack_standard(
graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked);
Expand Down