@@ -165,6 +165,7 @@ struct vk_device_struct {
165165 vk_queue transfer_queue;
166166 bool single_queue;
167167 uint32_t subgroup_size;
168+ uint32_t shader_core_count;
168169 bool uma;
169170
170171 size_t idx;
@@ -1498,7 +1499,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
14981499 ggml_vk_create_pipeline (device, device->pipeline_get_rows_f32 [GGML_TYPE_Q8_0], " get_rows_q8_0_f32" , get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {1024 , 1 , 1 }, {}, 1 );
14991500 ggml_vk_create_pipeline (device, device->pipeline_get_rows_f32 [GGML_TYPE_IQ4_NL], " get_rows_iq4_nl_f32" , get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {1024 , 1 , 1 }, {}, 1 );
15001501
1501- ggml_vk_create_pipeline (device, device->pipeline_matmul_split_k_reduce , " split_k_reduce" , split_k_reduce_len, split_k_reduce_data, " main" , 2 , 2 * sizeof (uint32_t ), {256 , 1 , 1 }, {}, 1 );
1502+ ggml_vk_create_pipeline (device, device->pipeline_matmul_split_k_reduce , " split_k_reduce" , split_k_reduce_len, split_k_reduce_data, " main" , 2 , 2 * sizeof (uint32_t ), {256 * 4 , 1 , 1 }, {}, 1 );
15021503
15031504 ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_p021_f16_f32 , " mul_mat_vec_p021_f16_f32" , mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, " main" , 3 , 6 * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
15041505 ggml_vk_create_pipeline (device, device->pipeline_mul_mat_vec_nc_f16_f32 , " mul_mat_vec_nc_f16_f32" , mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, " main" , 3 , 7 * sizeof (uint32_t ), {1 , 1 , 1 }, {}, 1 );
@@ -1610,23 +1611,36 @@ static vk_device ggml_vk_get_device(size_t idx) {
16101611 const std::vector<vk::ExtensionProperties> ext_props = device->physical_device .enumerateDeviceExtensionProperties ();
16111612
16121613 bool maintenance4_support = false ;
1614+ bool sm_builtins = false ;
16131615
16141616 // Check if maintenance4 is supported
16151617 for (const auto & properties : ext_props) {
16161618 if (strcmp (" VK_KHR_maintenance4" , properties.extensionName ) == 0 ) {
16171619 maintenance4_support = true ;
1620+ } else if (strcmp (" VK_NV_shader_sm_builtins" , properties.extensionName ) == 0 ) {
1621+ sm_builtins = true ;
16181622 }
16191623 }
16201624
16211625 vk::PhysicalDeviceProperties2 props2;
16221626 vk::PhysicalDeviceMaintenance3Properties props3;
16231627 vk::PhysicalDeviceMaintenance4Properties props4;
16241628 vk::PhysicalDeviceSubgroupProperties subgroup_props;
1629+ vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
16251630 props2.pNext = &props3;
16261631 props3.pNext = &subgroup_props;
1632+
1633+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&subgroup_props;
1634+
16271635 if (maintenance4_support) {
1628- subgroup_props.pNext = &props4;
1636+ last_struct->pNext = (VkBaseOutStructure *)&props4;
1637+ last_struct = (VkBaseOutStructure *)&props4;
1638+ }
1639+ if (sm_builtins) {
1640+ last_struct->pNext = (VkBaseOutStructure *)&sm_props;
1641+ last_struct = (VkBaseOutStructure *)&sm_props;
16291642 }
1643+
16301644 device->physical_device .getProperties2 (&props2);
16311645 device->properties = props2.properties ;
16321646
@@ -1643,6 +1657,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
16431657 device->vendor_id = device->properties .vendorID ;
16441658 device->subgroup_size = subgroup_props.subgroupSize ;
16451659 device->uma = device->properties .deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
1660+ if (sm_builtins) {
1661+ device->shader_core_count = sm_props.shaderSMCount ;
1662+ } else {
1663+ device->shader_core_count = 0 ;
1664+ }
16461665
16471666 bool fp16_storage = false ;
16481667 bool fp16_compute = false ;
@@ -2732,15 +2751,25 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
27322751 dst->device ->device .resetFences ({ dst->device ->fence });
27332752}
27342753
2735- static uint32_t ggml_vk_guess_split_k (int m, int n, int k) {
2754+ static uint32_t ggml_vk_guess_split_k (ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline ) {
27362755 VK_LOG_DEBUG (" ggml_vk_guess_split_k(" << m << " , " << n << " , " << k << " )" );
2737- // if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
2738- // return 4;
2739- // }
27402756
2741- return 1 ;
2757+ uint32_t split_k = 1 ;
2758+ if (ctx->device ->shader_core_count != 0 && m >= (int )pipeline->wg_denoms [0 ] && n >= (int )pipeline->wg_denoms [1 ]) {
2759+ // If k is 'large' and the SMs will fill less than halfway, use split_k.
2760+ uint32_t m_tiles = CEIL_DIV (m, pipeline->wg_denoms [0 ]);
2761+ uint32_t n_tiles = CEIL_DIV (n, pipeline->wg_denoms [1 ]);
2762+ if (k >= 2048 && m_tiles * n_tiles < ctx->device ->shader_core_count / 2 ) {
2763+ split_k = ctx->device ->shader_core_count / (m_tiles * n_tiles);
2764+ // Clamp to 2 or 4
2765+ split_k = std::min (split_k, 4u );
2766+ if (split_k == 3 ) {
2767+ split_k = 2 ;
2768+ }
2769+ }
2770+ }
27422771
2743- GGML_UNUSED (m); GGML_UNUSED (n); GGML_UNUSED (k) ;
2772+ return split_k ;
27442773}
27452774
27462775static vk_pipeline ggml_vk_guess_matmul_pipeline_amd (ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
@@ -2964,10 +2993,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
29642993 const uint32_t kpad = ggml_vk_align_size (ne10, ggml_vk_guess_matmul_pipeline_align (ctx, mmp, ne01, ne11));
29652994 const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8 ;
29662995
2967- const uint32_t split_k = ggml_vk_guess_split_k (ne01, ne11, ne10);
2968-
29692996 vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline (ctx, mmp, ne01, ne11, aligned);
29702997
2998+ const uint32_t split_k = ggml_vk_guess_split_k (ctx, ne01, ne11, ne10, pipeline);
2999+
29713000 const uint64_t qx_sz = ggml_type_size (src0->type ) * x_ne / ggml_blck_size (src0->type );
29723001 const uint64_t qy_sz = ggml_type_size (src1->type ) * y_ne / ggml_blck_size (src1->type );
29733002 const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof (ggml_fp16_t ) * x_ne;
@@ -2993,7 +3022,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
29933022 if (dryrun) {
29943023 const uint64_t x_sz_upd = x_sz * ne02 * ne03;
29953024 const uint64_t y_sz_upd = y_sz * ne12 * ne13;
2996- const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0 ;
3025+ const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0 ;
29973026 if (
29983027 (qx_needs_dequant && x_sz_upd > ctx->device ->max_memory_allocation_size ) ||
29993028 (qy_needs_dequant && y_sz_upd > ctx->device ->max_memory_allocation_size ) ||
0 commit comments