Commit 8e35380

Fix Gemma3n not being executed as a CUDA graph on NVIDIA GPUs
Gemma3n uses a matrix-matrix addition as part of its input processing, which wrongly triggered CUDA graph disablement on NVIDIA GPUs even when a batch size of 1 was used.
1 parent 086cf81 commit 8e35380
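
Below is a minimal standalone sketch of the counter heuristic this commit introduces, for illustration only. The names mock_node, graph_is_cuda_graph_compatible, and the two sample graphs are hypothetical stand-ins, not part of ggml; the actual logic lives in check_node_graph_compatibility_and_refresh_copy_ops in ggml/src/ggml-cuda/ggml-cuda.cu, shown in the diff below.

// Hypothetical sketch of the heuristic; see lead-in above.
#include <cstdint>
#include <cstdio>
#include <vector>

enum op_t { OP_ADD, OP_MUL_MAT, OP_CPY };

struct mock_node {
    op_t    op;       // stand-in for ggml_tensor::op
    int64_t src1_ne1; // stand-in for node->src[1]->ne[1]
};

// Mirrors the new heuristic: a single matrix-matrix GGML_OP_ADD (as in
// Gemma3n's project_per_layer_input) keeps CUDA graphs enabled; a second
// such node is taken as evidence of batched execution and disables them.
static bool graph_is_cuda_graph_compatible(const std::vector<mock_node> & nodes) {
    std::uint8_t batch_size_counter = 0;
    for (const mock_node & node : nodes) {
        if (node.op == OP_ADD && node.src1_ne1 > 1) {
            ++batch_size_counter;
            if (batch_size_counter > 1) {
                return false;
            }
        }
    }
    return true;
}

int main() {
    // Gemma3n-like decode at batch size 1: exactly one matrix-matrix add.
    const std::vector<mock_node> gemma3n = { { OP_MUL_MAT, 1 }, { OP_ADD, 256 }, { OP_CPY, 1 } };
    // Batched decode: repeated adds with ne[1] > 1.
    const std::vector<mock_node> batched = { { OP_ADD, 8 }, { OP_MUL_MAT, 8 }, { OP_ADD, 8 } };

    std::printf("gemma3n-like graph keeps CUDA graphs: %s\n", graph_is_cuda_graph_compatible(gemma3n) ? "yes" : "no");
    std::printf("batched graph keeps CUDA graphs:      %s\n", graph_is_cuda_graph_compatible(batched) ? "yes" : "no");
    return 0;
}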

File tree

1 file changed: +13 −6 lines changed


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 13 additions & 6 deletions
@@ -2589,6 +2589,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
     cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
+    std::uint8_t batch_size_counter = 0;
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
@@ -2612,12 +2613,18 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
         }
 
         if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-            // disable CUDA graphs for batch size > 1 for now.
-            // Changes in batch size or context size can cause changes to the grid size of some kernels.
-            use_cuda_graph = false;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
+            // disable CUDA graphs for batch size > 1 for now. The heuristic here allows CUDA graphs to be used
+            // for Gemma3n, which uses a single matrix-matrix addition as part of `project_per_layer_input`, while
+            // detecting batched execution for all graphs with >1 GGML_OP_ADD nodes. See also
+            // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773.
+            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
+            ++batch_size_counter;
+            if (batch_size_counter > 1) {
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to repeated batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+            }
         }
 
         if (node->op == GGML_OP_CPY) {
