@@ -71,46 +71,59 @@ void add_q_8w_linear_node(
7171 const ValueRef q_mat2_data,
7272 const ValueRef scales_data,
7373 const ValueRef out) {
74+ auto viewFn = VK_GET_OP_FN (" aten.view_copy.default" );
75+ ValueRef mat1_W_packed = mat1;
76+ ValueRef out_W_packed = out;
77+ if (!graph.is_buffer_storage (out) && graph.packed_dim_of (mat1) != WHCN::kWidthDim ) {
78+ // Ensure mat1 is width packed
79+ mat1_W_packed = graph.add_tensor_like (mat1, utils::kWidthPacked );
80+ viewFn (graph, {mat1, graph.add_none (), mat1_W_packed});
81+ // Ensure out is packed correctly
82+ out_W_packed = graph.add_tensor_like (out, utils::kWidthPacked );
83+ }
7484 ValueRef q_mat2 =
7585 prepack_if_tensor_ref (graph, q_mat2_data, utils::kWidthPacked );
7686 ValueRef scales =
7787 prepack_if_tensor_ref (graph, scales_data, utils::kWidthPacked );
7888
7989 std::string kernel_name = " q_8w_linear" ;
8090 kernel_name.reserve (kShaderNameReserve );
81- add_packed_dim_suffix (kernel_name, graph.packed_dim_of (mat1 ));
91+ add_packed_dim_suffix (kernel_name, graph.packed_dim_of (mat1_W_packed ));
8292 add_packed_dim_suffix (kernel_name, graph.packed_dim_of (q_mat2));
83- add_dtype_suffix (kernel_name, graph.dtype_of (out ));
84- add_storage_type_suffix (kernel_name, graph.storage_type_of (out ));
93+ add_dtype_suffix (kernel_name, graph.dtype_of (out_W_packed ));
94+ add_storage_type_suffix (kernel_name, graph.storage_type_of (out_W_packed ));
8595
8696 vkapi::ParamsBindList ubos ({});
87- if (graph.is_buffer_storage (out )) {
97+ if (graph.is_buffer_storage (out_W_packed )) {
8898 ubos.append (
89- {graph.sizes_ubo (out ),
90- graph.strides_ubo (out ),
91- graph.numel_ubo (out ),
92- graph.sizes_ubo (mat1 ),
99+ {graph.sizes_ubo (out_W_packed ),
100+ graph.strides_ubo (out_W_packed ),
101+ graph.numel_ubo (out_W_packed ),
102+ graph.sizes_ubo (mat1_W_packed ),
93103 graph.strides_ubo (mat1),
94104 graph.strides_ubo (q_mat2),
95105 graph.strides_ubo (scales)});
96106 } else {
97- ubos.append ({graph.logical_limits_ubo (out ), graph.sizes_ubo (mat1 )});
107+ ubos.append ({graph.logical_limits_ubo (out_W_packed ), graph.sizes_ubo (mat1_W_packed )});
98108 }
99109
100110 graph.execute_nodes ().emplace_back (new DispatchNode (
101111 graph,
102112 VK_KERNEL_FROM_STR (kernel_name),
103- graph.create_global_wg_size (out ),
104- graph.create_local_wg_size (out ),
113+ graph.create_global_wg_size (out_W_packed ),
114+ graph.create_local_wg_size (out_W_packed ),
105115 // Inputs and Outputs
106- {{out , vkapi::MemoryAccessType::WRITE},
107- {{mat1 , q_mat2, scales}, vkapi::MemoryAccessType::READ}},
116+ {{out_W_packed , vkapi::MemoryAccessType::WRITE},
117+ {{mat1_W_packed , q_mat2, scales}, vkapi::MemoryAccessType::READ}},
108118 // Shader params buffers
109119 ubos,
110120 // Specialization Constants
111121 {},
112122 // Resizing Logic
113123 resize_qlinear_node));
124+ if (!graph.is_buffer_storage (out) && graph.packed_dim_of (out) != WHCN::kWidthDim ) {
125+ viewFn (graph, {out_W_packed, graph.add_none (), out});
126+ }
114127}
115128
116129void weight_int8pack_mm (
0 commit comments