@@ -8436,8 +8436,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
84368436 VK_LOG_DEBUG (" ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)" );
84378437 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context ;
84388438
8439+ uint64_t total_mat_mul_bytes = 0 ;
84398440 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
84408441 ggml_vk_build_graph (ctx, cgraph->nodes [i], i, nullptr , 0 , true , false , false );
8442+ if (cgraph->nodes [i]->op == GGML_OP_MUL_MAT || cgraph->nodes [i]->op == GGML_OP_MUL_MAT_ID) {
8443+ total_mat_mul_bytes += ggml_nbytes (cgraph->nodes [i]->src [0 ]);
8444+ }
84418445 }
84428446 if (ctx->device ->need_compiles ) {
84438447 ggml_vk_load_shaders (ctx->device );
@@ -8458,17 +8462,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
84588462 bool first_node_in_batch = true ; // true if next node will be first node in a batch
84598463 int submit_node_idx = 0 ; // index to first node in a batch
84608464
8461- // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
8462- // Start with a smaller count to get work submitted right away, and increase it after each submit.
8463- int nodes_per_submit = 20 ;
8465+ // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
8466+ // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
8467+ // (and scaled down based on model size, so smaller models submit earlier).
8468+ // Also submit at least every 100 nodes, in case there are workloads without as much matmul.
8469+ int nodes_per_submit = 100 ;
84648470 int submitted_nodes = 0 ;
84658471 int submit_count = 0 ;
8472+ uint64_t mul_mat_bytes = 0 ;
8473+ uint64_t mul_mat_bytes_per_submit = std::min (uint64_t (100 *1000 *1000 ), total_mat_mul_bytes / 40u );
84668474 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
84678475 if (first_node_in_batch) {
84688476 submit_node_idx = i;
84698477 }
84708478
8471- bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
8479+ if (cgraph->nodes [i]->op == GGML_OP_MUL_MAT || cgraph->nodes [i]->op == GGML_OP_MUL_MAT_ID) {
8480+ mul_mat_bytes += ggml_nbytes (cgraph->nodes [i]->src [0 ]);
8481+ }
8482+
8483+ bool submit = (submitted_nodes >= nodes_per_submit) ||
8484+ (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
8485+ (i == last_node);
84728486
84738487 bool enqueued = ggml_vk_build_graph (ctx, cgraph->nodes [i], i, cgraph->nodes [submit_node_idx], submit_node_idx, false , i == last_node, submit);
84748488
@@ -8485,13 +8499,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
84858499 if (submit) {
84868500 first_node_in_batch = true ;
84878501 submitted_nodes = 0 ;
8488- switch (submit_count) {
8489- case 0 :
8490- nodes_per_submit = 50 ;
8491- break ;
8492- default :
8493- nodes_per_submit = 100 ;
8494- break ;
8502+ mul_mat_bytes = 0 ;
8503+ if (submit_count < 3 ) {
8504+ mul_mat_bytes_per_submit *= 2 ;
84958505 }
84968506 submit_count++;
84978507 }
0 commit comments