@@ -8245,8 +8245,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
82458245 VK_LOG_DEBUG (" ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)" );
82468246 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context ;
82478247
8248+ uint64_t total_mat_mul_bytes = 0 ;
82488249 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
82498250 ggml_vk_build_graph (ctx, cgraph->nodes [i], i, nullptr , 0 , true , false , false );
8251+ if (cgraph->nodes [i]->op == GGML_OP_MUL_MAT || cgraph->nodes [i]->op == GGML_OP_MUL_MAT_ID) {
8252+ total_mat_mul_bytes += ggml_nbytes (cgraph->nodes [i]->src [0 ]);
8253+ }
82508254 }
82518255 if (ctx->device ->need_compiles ) {
82528256 ggml_vk_load_shaders (ctx->device );
@@ -8267,17 +8271,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
82678271 bool first_node_in_batch = true ; // true if next node will be first node in a batch
82688272 int submit_node_idx = 0 ; // index to first node in a batch
82698273
8270- // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
8271- // Start with a smaller count to get work submitted right away, and increase it after each submit.
8272- int nodes_per_submit = 20 ;
8274+ // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
8275+ // Estimate the amount of matmul work by looking at the weight matrix size. The initial threshold is 100MB,
8276+ // or 1/40th of the graph's total matmul weight bytes if that is smaller (so smaller models submit earlier); it doubles after each of the first three submits.
8277+ // Also submit at least every 100 nodes, in case there are workloads without as much matmul.
8278+ int nodes_per_submit = 100 ;
82738279 int submitted_nodes = 0 ;
82748280 int submit_count = 0 ;
8281+ uint64_t mul_mat_bytes = 0 ;
8282+ uint64_t mul_mat_bytes_per_submit = std::min (uint64_t (100 *1000 *1000 ), total_mat_mul_bytes / 40u );
82758283 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
82768284 if (first_node_in_batch) {
82778285 submit_node_idx = i;
82788286 }
82798287
8280- bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
8288+ if (cgraph->nodes [i]->op == GGML_OP_MUL_MAT || cgraph->nodes [i]->op == GGML_OP_MUL_MAT_ID) {
8289+ mul_mat_bytes += ggml_nbytes (cgraph->nodes [i]->src [0 ]);
8290+ }
8291+
8292+ bool submit = (submitted_nodes >= nodes_per_submit) ||
8293+ (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
8294+ (i == last_node);
82818295
82828296 bool enqueued = ggml_vk_build_graph (ctx, cgraph->nodes [i], i, cgraph->nodes [submit_node_idx], submit_node_idx, false , i == last_node, submit);
82838297
@@ -8294,13 +8308,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
82948308 if (submit) {
82958309 first_node_in_batch = true ;
82968310 submitted_nodes = 0 ;
8297- switch (submit_count) {
8298- case 0 :
8299- nodes_per_submit = 50 ;
8300- break ;
8301- default :
8302- nodes_per_submit = 100 ;
8303- break ;
8311+ mul_mat_bytes = 0 ;
8312+ if (submit_count < 3 ) {
8313+ mul_mat_bytes_per_submit *= 2 ;
83048314 }
83058315 submit_count++;
83068316 }
0 commit comments