Commit 61cff1b

Enable CUDA graphs for embed gemma 300m
1 parent: f549b00 · commit: 61cff1b

File tree

2 files changed: +7 additions, −3 deletions


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 1 deletion
@@ -2806,6 +2806,8 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
     const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
     const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
     const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
+    const std::string emGemma_sa_out_prefix = "emGemma_sa_out";
+    const std::string emGemma_l_out_prefix = "emGemma_l_out";
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
@@ -2836,7 +2838,9 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
             strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
             strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
             strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
-            strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
+            strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0 &&
+            strncmp(node->name, emGemma_sa_out_prefix.c_str(), emGemma_sa_out_prefix.size()) != 0 &&
+            strncmp(node->name, emGemma_l_out_prefix.c_str(), emGemma_l_out_prefix.size()) != 0) {
             // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
             // by means of matching node names. See
             // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
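Note: the condition above acts as an allow-list. A batched (ne[1] > 1) ADD node forces CUDA graphs off unless its name starts with one of the known-safe prefixes, and this commit adds the two embedding-Gemma prefixes to that list. A minimal standalone sketch of the matching logic (illustrative only; the helper name and list layout are assumptions, not the actual llama.cpp code):

#include <cstring>
#include <string>
#include <vector>

// Hypothetical helper mirroring the allow-list check in the diff above:
// returns true if a batched ADD node's name starts with a prefix that is
// known to be safe for CUDA graph capture.
static bool batched_add_is_graph_safe(const char * node_name) {
    static const std::vector<std::string> safe_prefixes = {
        "ffn_moe_down_biased",
        "nemotron_h_block_out",
        "mamba2_y_add_d",
        "emGemma_sa_out", // added by this commit
        "emGemma_l_out",  // added by this commit
    };
    for (const std::string & prefix : safe_prefixes) {
        // Same comparison style as the diff: only the first prefix.size()
        // characters are compared, so per-layer name suffixes still match.
        if (std::strncmp(node_name, prefix.c_str(), prefix.size()) == 0) {
            return true;
        }
    }
    return false; // unknown batched add: CUDA graphs get disabled
}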

src/llama-model.cpp

Lines changed: 2 additions & 2 deletions
@@ -11573,7 +11573,7 @@ struct llm_build_gemma_embedding : public llm_graph_context {
         cb(cur, "attn_post_norm", il);
 
         ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
-        cb(sa_out, "sa_out", il);
+        cb(sa_out, "emGemma_sa_out", il);
 
         cur = build_norm(sa_out,
                 model.layers[il].ffn_norm, NULL,
@@ -11599,7 +11599,7 @@ struct llm_build_gemma_embedding : public llm_graph_context {
         cur = ggml_add(ctx0, cur, sa_out);
 
         cur = build_cvec(cur, il);
-        cb(cur, "l_out", il);
+        cb(cur, "emGemma_l_out", il);
 
         // input for next layer
         inpL = cur;
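For context: cb(...) is the graph-build callback that assigns these names to the tensors the CUDA backend later inspects. In llama.cpp's graph builders it typically appends the layer index to the name (an assumption here, based on ggml_format_name usage elsewhere in the codebase), which is why the backend compares only the prefix. A small sketch of that assumed naming scheme:

#include <cstdio>

// Sketch of the assumed per-layer naming: cb(sa_out, "emGemma_sa_out", il)
// would yield names like "emGemma_sa_out-7", so the backend's strncmp
// prefix check covers every layer with a single allow-list entry.
int main() {
    for (int il = 0; il < 3; ++il) { // hypothetical layer indices
        char name[64];
        std::snprintf(name, sizeof(name), "emGemma_sa_out-%d", il);
        std::printf("%s\n", name); // emGemma_sa_out-0, -1, -2
    }
    return 0;
}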
