@@ -260,9 +260,6 @@ void check_q_4w_linear_args(
260260  const  int  group_size_val = graph.extract_scalar <int >(group_size);
261261  VK_CHECK_COND (K % group_size_val == 0 );
262262
263-   VK_CHECK_COND (graph.packed_dim_of (mat1) == WHCN::kWidthDim );
264-   VK_CHECK_COND (graph.packed_dim_of (out) == WHCN::kWidthDim );
265- 
266263  VK_CHECK_COND (graph.has_standard_axis_map (mat1));
267264  VK_CHECK_COND (graph.has_standard_axis_map (out));
268265}
@@ -320,13 +317,32 @@ void add_q_4w_linear_node(
320317
321318  const  uint32_t  group_size_val = graph.extract_scalar <uint32_t >(group_size);
322319
320+   ValueRef mat1_W_packed = mat1;
321+   ValueRef out_W_packed = out;
322+   auto  viewFn = VK_GET_OP_FN (" aten.view_copy.default"  );
323+   //  Create temporary tensors to store the width packed versions of mat1 and out
324+   TmpTensor mat1_tmp (
325+       &graph, graph.sizes_of (mat1), graph.dtype_of (mat1), utils::kWidthPacked );
326+   TmpTensor out_tmp (
327+       &graph, graph.sizes_of (out), graph.dtype_of (out), utils::kWidthPacked );
328+   if  (storage_type == utils::kTexture3D ) {
329+     if  (!graph.is_buffer_storage (out) &&
330+         graph.packed_dim_of (mat1) != WHCN::kWidthDim ) {
331+       //  Ensure mat1 is width packed
332+       mat1_W_packed = mat1_tmp;
333+       viewFn (graph, {mat1, graph.add_none (), mat1_W_packed});
334+       //  Ensure out is packed correctly
335+       out_W_packed = out_tmp;
336+     }
337+   }
338+ 
323339  vkapi::ParamsBindList ubos ({});
324-   ubos.append (graph.logical_limits_ubo (out ));
325-   ubos.append (graph.sizes_ubo (mat1 ));
340+   ubos.append (graph.logical_limits_ubo (out_W_packed ));
341+   ubos.append (graph.sizes_ubo (mat1_W_packed ));
326342  ubos.append (graph.strides_ubo (mat2));
327343  ubos.append (graph.strides_ubo (scales_and_zeros));
328344
329-   utils::uvec3 global_wg_size = graph.logical_limits_of (out );
345+   utils::uvec3 global_wg_size = graph.logical_limits_of (out_W_packed );
330346  utils::uvec3 local_wg_size = graph.create_local_wg_size (global_wg_size);
331347
332348  graph.execute_nodes ().emplace_back (new  DispatchNode (
@@ -335,15 +351,19 @@ void add_q_4w_linear_node(
335351      global_wg_size,
336352      local_wg_size,
337353      //  Inputs and Outputs
338-       {{out , vkapi::MemoryAccessType::WRITE},
339-        {{mat1 , mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}},
354+       {{out_W_packed , vkapi::MemoryAccessType::WRITE},
355+        {{mat1_W_packed , mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}},
340356      //  Shader params buffers
341357      ubos,
342358      //  Specialization Constants
343359      {SV (group_size_val)},
344360      //  Resizing Logic
345361      resize_q_4w_linear_node,
346362      {}));
363+   if  (!graph.is_buffer_storage (out) &&
364+       graph.packed_dim_of (out) != WHCN::kWidthDim ) {
365+     viewFn (graph, {out_W_packed, graph.add_none (), out});
366+   }
347367}
348368
349369void  linear_weight_int4 (
0 commit comments