@@ -260,9 +260,6 @@ void check_q_4w_linear_args(
260260 const int group_size_val = graph.extract_scalar <int >(group_size);
261261 VK_CHECK_COND (K % group_size_val == 0 );
262262
263- VK_CHECK_COND (graph.packed_dim_of (mat1) == WHCN::kWidthDim );
264- VK_CHECK_COND (graph.packed_dim_of (out) == WHCN::kWidthDim );
265-
266263 VK_CHECK_COND (graph.has_standard_axis_map (mat1));
267264 VK_CHECK_COND (graph.has_standard_axis_map (out));
268265}
@@ -320,13 +317,32 @@ void add_q_4w_linear_node(
320317
321318 const uint32_t group_size_val = graph.extract_scalar <uint32_t >(group_size);
322319
320+ ValueRef mat1_W_packed = mat1;
321+ ValueRef out_W_packed = out;
322+ auto viewFn = VK_GET_OP_FN (" aten.view_copy.default" );
323+ // Create temporary tensors to store the width packed versions of mat1 and out
324+ TmpTensor mat1_tmp (
325+ &graph, graph.sizes_of (mat1), graph.dtype_of (mat1), utils::kWidthPacked );
326+ TmpTensor out_tmp (
327+ &graph, graph.sizes_of (out), graph.dtype_of (out), utils::kWidthPacked );
328+ if (storage_type == utils::kTexture3D ) {
329+ if (!graph.is_buffer_storage (out) &&
330+ graph.packed_dim_of (mat1) != WHCN::kWidthDim ) {
331+ // Ensure mat1 is width packed
332+ mat1_W_packed = mat1_tmp;
333+ viewFn (graph, {mat1, graph.add_none (), mat1_W_packed});
334+ // Ensure out is packed correctly
335+ out_W_packed = out_tmp;
336+ }
337+ }
338+
323339 vkapi::ParamsBindList ubos ({});
324- ubos.append (graph.logical_limits_ubo (out ));
325- ubos.append (graph.sizes_ubo (mat1 ));
340+ ubos.append (graph.logical_limits_ubo (out_W_packed ));
341+ ubos.append (graph.sizes_ubo (mat1_W_packed ));
326342 ubos.append (graph.strides_ubo (mat2));
327343 ubos.append (graph.strides_ubo (scales_and_zeros));
328344
329- utils::uvec3 global_wg_size = graph.logical_limits_of (out );
345+ utils::uvec3 global_wg_size = graph.logical_limits_of (out_W_packed );
330346 utils::uvec3 local_wg_size = graph.create_local_wg_size (global_wg_size);
331347
332348 graph.execute_nodes ().emplace_back (new DispatchNode (
@@ -335,15 +351,20 @@ void add_q_4w_linear_node(
335351 global_wg_size,
336352 local_wg_size,
337353 // Inputs and Outputs
338- {{out, vkapi::MemoryAccessType::WRITE},
339- {{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}},
354+ {{out_W_packed, vkapi::MemoryAccessType::WRITE},
355+ {{mat1_W_packed, mat2, scales_and_zeros},
356+ vkapi::MemoryAccessType::READ}},
340357 // Shader params buffers
341358 ubos,
342359 // Specialization Constants
343360 {SV (group_size_val)},
344361 // Resizing Logic
345362 resize_q_4w_linear_node,
346363 {}));
364+ if (!graph.is_buffer_storage (out) &&
365+ graph.packed_dim_of (out) != WHCN::kWidthDim ) {
366+ viewFn (graph, {out_W_packed, graph.add_none (), out});
367+ }
347368}
348369
349370void linear_weight_int4 (
0 commit comments