@@ -71,46 +71,63 @@ void add_q_8w_linear_node(
7171 const ValueRef q_mat2_data,
7272 const ValueRef scales_data,
7373 const ValueRef out) {
74+ auto viewFn = VK_GET_OP_FN (" aten.view_copy.default" );
75+ ValueRef mat1_W_packed = mat1;
76+ ValueRef out_W_packed = out;
77+ if (!graph.is_buffer_storage (out) &&
78+ graph.packed_dim_of (mat1) != WHCN::kWidthDim ) {
79+ // Ensure mat1 is width packed
80+ mat1_W_packed = graph.add_tensor_like (mat1, utils::kWidthPacked );
81+ viewFn (graph, {mat1, graph.add_none (), mat1_W_packed});
82+ // Ensure out is packed correctly
83+ out_W_packed = graph.add_tensor_like (out, utils::kWidthPacked );
84+ }
7485 ValueRef q_mat2 =
7586 prepack_if_tensor_ref (graph, q_mat2_data, utils::kWidthPacked );
7687 ValueRef scales =
7788 prepack_if_tensor_ref (graph, scales_data, utils::kWidthPacked );
7889
7990 std::string kernel_name = " q_8w_linear" ;
8091 kernel_name.reserve (kShaderNameReserve );
81- add_packed_dim_suffix (kernel_name, graph.packed_dim_of (mat1 ));
92+ add_packed_dim_suffix (kernel_name, graph.packed_dim_of (mat1_W_packed ));
8293 add_packed_dim_suffix (kernel_name, graph.packed_dim_of (q_mat2));
83- add_dtype_suffix (kernel_name, graph.dtype_of (out ));
84- add_storage_type_suffix (kernel_name, graph.storage_type_of (out ));
94+ add_dtype_suffix (kernel_name, graph.dtype_of (out_W_packed ));
95+ add_storage_type_suffix (kernel_name, graph.storage_type_of (out_W_packed ));
8596
8697 vkapi::ParamsBindList ubos ({});
87- if (graph.is_buffer_storage (out )) {
98+ if (graph.is_buffer_storage (out_W_packed )) {
8899 ubos.append (
89- {graph.sizes_ubo (out ),
90- graph.strides_ubo (out ),
91- graph.numel_ubo (out ),
92- graph.sizes_ubo (mat1 ),
100+ {graph.sizes_ubo (out_W_packed ),
101+ graph.strides_ubo (out_W_packed ),
102+ graph.numel_ubo (out_W_packed ),
103+ graph.sizes_ubo (mat1_W_packed ),
93104 graph.strides_ubo (mat1),
94105 graph.strides_ubo (q_mat2),
95106 graph.strides_ubo (scales)});
96107 } else {
97- ubos.append ({graph.logical_limits_ubo (out), graph.sizes_ubo (mat1)});
108+ ubos.append (
109+ {graph.logical_limits_ubo (out_W_packed),
110+ graph.sizes_ubo (mat1_W_packed)});
98111 }
99112
100113 graph.execute_nodes ().emplace_back (new DispatchNode (
101114 graph,
102115 VK_KERNEL_FROM_STR (kernel_name),
103- graph.create_global_wg_size (out ),
104- graph.create_local_wg_size (out ),
116+ graph.create_global_wg_size (out_W_packed ),
117+ graph.create_local_wg_size (out_W_packed ),
105118 // Inputs and Outputs
106- {{out , vkapi::MemoryAccessType::WRITE},
107- {{mat1 , q_mat2, scales}, vkapi::MemoryAccessType::READ}},
119+ {{out_W_packed , vkapi::MemoryAccessType::WRITE},
120+ {{mat1_W_packed , q_mat2, scales}, vkapi::MemoryAccessType::READ}},
108121 // Shader params buffers
109122 ubos,
110123 // Specialization Constants
111124 {},
112125 // Resizing Logic
113126 resize_qlinear_node));
127+ if (!graph.is_buffer_storage (out) &&
128+ graph.packed_dim_of (out) != WHCN::kWidthDim ) {
129+ viewFn (graph, {out_W_packed, graph.add_none (), out});
130+ }
114131}
115132
116133void weight_int8pack_mm (
0 commit comments