@@ -165,6 +165,7 @@ struct vk_device_struct {
165165    vk_queue transfer_queue;
166166    bool  single_queue;
167167    uint32_t  subgroup_size;
168+     uint32_t  shader_core_count;
168169    bool  uma;
169170
170171    size_t  idx;
@@ -1498,7 +1499,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
14981499    ggml_vk_create_pipeline (device, device->pipeline_get_rows_f32 [GGML_TYPE_Q8_0], " get_rows_q8_0_f32"  , get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, " main"  , 3 , sizeof (vk_op_binary_push_constants), {1024 , 1 , 1 }, {}, 1 );
14991500    ggml_vk_create_pipeline (device, device->pipeline_get_rows_f32 [GGML_TYPE_IQ4_NL], " get_rows_iq4_nl_f32"  , get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, " main"  , 3 , sizeof (vk_op_binary_push_constants), {1024 , 1 , 1 }, {}, 1 );
15001501
1501-     ggml_vk_create_pipeline (device, device->pipeline_matmul_split_k_reduce , " split_k_reduce"  , split_k_reduce_len, split_k_reduce_data, " main"  , 2 , 2  * sizeof (uint32_t ), {256 , 1 , 1 }, {}, 1 );
1502+     ggml_vk_create_pipeline (device, device->pipeline_matmul_split_k_reduce , " split_k_reduce"  , split_k_reduce_len, split_k_reduce_data, " main"  , 2 , 2  * sizeof (uint32_t ), {256  *  4 , 1 , 1 }, {}, 1 );
15021503
15031504    ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 , " mul_mat_vec_p021_f16_f32"  , mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, " main"  , 3 , 6  * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
15041505    ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_nc_f16_f32 , " mul_mat_vec_nc_f16_f32"  , mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, " main"  , 3 , 7  * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
@@ -1610,23 +1611,36 @@ static vk_device ggml_vk_get_device(size_t idx) {
16101611        const  std::vector<vk::ExtensionProperties> ext_props = device->physical_device .enumerateDeviceExtensionProperties ();
16111612
16121613        bool  maintenance4_support = false ;
1614+         bool  sm_builtins = false ;
16131615
16141616        //  Check if maintenance4 is supported
16151617        for  (const  auto & properties : ext_props) {
16161618            if  (strcmp (" VK_KHR_maintenance4"  , properties.extensionName ) == 0 ) {
16171619                maintenance4_support = true ;
1620+             } else  if  (strcmp (" VK_NV_shader_sm_builtins"  , properties.extensionName ) == 0 ) {
1621+                 sm_builtins = true ;
16181622            }
16191623        }
16201624
16211625        vk::PhysicalDeviceProperties2 props2;
16221626        vk::PhysicalDeviceMaintenance3Properties props3;
16231627        vk::PhysicalDeviceMaintenance4Properties props4;
16241628        vk::PhysicalDeviceSubgroupProperties subgroup_props;
1629+         vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
16251630        props2.pNext  = &props3;
16261631        props3.pNext  = &subgroup_props;
1632+ 
1633+         VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&subgroup_props;
1634+ 
16271635        if  (maintenance4_support) {
1628-             subgroup_props.pNext  = &props4;
1636+             last_struct->pNext  = (VkBaseOutStructure *)&props4;
1637+             last_struct = (VkBaseOutStructure *)&props4;
1638+         }
1639+         if  (sm_builtins) {
1640+             last_struct->pNext  = (VkBaseOutStructure *)&sm_props;
1641+             last_struct = (VkBaseOutStructure *)&sm_props;
16291642        }
1643+ 
16301644        device->physical_device .getProperties2 (&props2);
16311645        device->properties  = props2.properties ;
16321646
@@ -1643,6 +1657,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
16431657        device->vendor_id  = device->properties .vendorID ;
16441658        device->subgroup_size  = subgroup_props.subgroupSize ;
16451659        device->uma  = device->properties .deviceType  == vk::PhysicalDeviceType::eIntegratedGpu;
1660+         if  (sm_builtins) {
1661+             device->shader_core_count  = sm_props.shaderSMCount ;
1662+         } else  {
1663+             device->shader_core_count  = 0 ;
1664+         }
16461665
16471666        bool  fp16_storage = false ;
16481667        bool  fp16_compute = false ;
@@ -2732,15 +2751,25 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
27322751    dst->device ->device .resetFences ({ dst->device ->fence  });
27332752}
27342753
2735- static  uint32_t  ggml_vk_guess_split_k (int  m, int  n, int  k) {
2754+ static  uint32_t  ggml_vk_guess_split_k (ggml_backend_vk_context * ctx,  int  m, int  n, int  k,  const  vk_pipeline& pipeline ) {
27362755    VK_LOG_DEBUG (" ggml_vk_guess_split_k("   << m << " , "   << n << " , "   << k << " )"  );
2737-     //  if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
2738-     //      return 4;
2739-     //  }
27402756
2741-     return  1 ;
2757+     uint32_t  split_k = 1 ;
2758+     if  (ctx->device ->shader_core_count  != 0  && m >= (int )pipeline->wg_denoms [0 ] && n >= (int )pipeline->wg_denoms [1 ]) {
2759+         //  If k is 'large' and the SMs will fill less than halfway, use split_k.
2760+         uint32_t  m_tiles = CEIL_DIV (m, pipeline->wg_denoms [0 ]);
2761+         uint32_t  n_tiles = CEIL_DIV (n, pipeline->wg_denoms [1 ]);
2762+         if  (k >= 2048  && m_tiles * n_tiles < ctx->device ->shader_core_count  / 2 ) {
2763+             split_k = ctx->device ->shader_core_count  / (m_tiles * n_tiles);
2764+             //  Clamp to 2 or 4
2765+             split_k = std::min (split_k, 4u );
2766+             if  (split_k == 3 ) {
2767+                 split_k = 2 ;
2768+             }
2769+         }
2770+     }
27422771
2743-     GGML_UNUSED (m);  GGML_UNUSED (n);  GGML_UNUSED (k) ;
2772+     return  split_k ;
27442773}
27452774
27462775static  vk_pipeline ggml_vk_guess_matmul_pipeline_amd (ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int  m, int  n, bool  aligned) {
@@ -2964,10 +2993,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
29642993    const  uint32_t  kpad = ggml_vk_align_size (ne10, ggml_vk_guess_matmul_pipeline_align (ctx, mmp, ne01, ne11));
29652994    const  bool  aligned = ne10 == kpad && ne01 > 8  && ne11 > 8 ;
29662995
2967-     const  uint32_t  split_k = ggml_vk_guess_split_k (ne01, ne11, ne10);
2968- 
29692996    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline (ctx, mmp, ne01, ne11, aligned);
29702997
2998+     const  uint32_t  split_k = ggml_vk_guess_split_k (ctx, ne01, ne11, ne10, pipeline);
2999+ 
29713000    const  uint64_t  qx_sz = ggml_type_size (src0->type ) * x_ne / ggml_blck_size (src0->type );
29723001    const  uint64_t  qy_sz = ggml_type_size (src1->type ) * y_ne / ggml_blck_size (src1->type );
29733002    const  uint64_t  x_sz = !qx_needs_dequant ? qx_sz : sizeof (ggml_fp16_t ) * x_ne;
@@ -2993,7 +3022,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
29933022    if  (dryrun) {
29943023        const  uint64_t  x_sz_upd = x_sz * ne02 * ne03;
29953024        const  uint64_t  y_sz_upd = y_sz * ne12 * ne13;
2996-         const  uint64_t  split_k_size = split_k > 1  ? d_sz * ne12 * ne13 * 4  : 0 ;
3025+         const  uint64_t  split_k_size = split_k > 1  ? d_sz * ne12 * ne13 * split_k  : 0 ;
29973026        if  (
29983027                (qx_needs_dequant && x_sz_upd > ctx->device ->max_memory_allocation_size ) ||
29993028                (qy_needs_dequant && y_sz_upd > ctx->device ->max_memory_allocation_size ) ||
0 commit comments