@@ -2589,6 +2589,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
25892589
25902590 // Loop over nodes in GGML graph to obtain info needed for CUDA graph
25912591 cuda_ctx->cuda_graph ->cpy_dest_ptrs .clear ();
2592+ std::uint8_t batch_size_counter = 0 ;
25922593
25932594 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
25942595 ggml_tensor * node = cgraph->nodes [i];
@@ -2612,12 +2613,18 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
26122613 }
26132614
26142615 if (node->op == GGML_OP_ADD && node->src [1 ] && node->src [1 ]->ne [1 ] > 1 ) {
2615- // disable CUDA graphs for batch size > 1 for now.
2616- // Changes in batch size or context size can cause changes to the grid size of some kernels.
2617- use_cuda_graph = false ;
2618- #ifndef NDEBUG
2619- GGML_LOG_DEBUG (" %s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n " , __func__, node->name , node->ne [0 ], node->ne [1 ], node->ne [2 ], node->ne [3 ]);
2620- #endif
2616+ // disable CUDA graphs for batch size > 1 for now. The heuristic here allows to use CUDA graphs
2617+ // for Gemma3n, which uses a single Matrix-Matrix Addition as part of `project_per_layer_input`, while detecting
2618+ // batched execution for all graphs with >1 GGML_OP_ADD nodes. See also
2619+ // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
2620+ // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
2621+ ++batch_size_counter;
2622+ if (batch_size_counter > 1 ) {
2623+ use_cuda_graph = false ;
2624+ #ifndef NDEBUG
2625+ GGML_LOG_DEBUG (" %s: disabling CUDA graphs due to repeated batch size > 1 [%s] [%ld %ld %ld %ld]\n " , __func__, node->name , node->ne [0 ], node->ne [1 ], node->ne [2 ], node->ne [3 ]);
2626+ #endif
2627+ }
26212628 }
26222629
26232630 if (node->op == GGML_OP_CPY) {
0 commit comments