Commit 2c3cfa9

fix: rms norm fusion causes delays in launching the mul-mat

1 parent: a5afc0c
File tree

1 file changed: +17 −10 lines

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 17 additions & 10 deletions
@@ -3173,12 +3173,13 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 GGML_ASSERT(concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());
                 const int stream_mapping = concurrent_event->stream_mapping[node];
                 cuda_ctx->curr_stream_no = stream_mapping;
+                GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", stream_mapping, node->name);
             }
         }
-
-#ifdef GGML_CUDA_DEBUG
         const int nodes_fused = i - prev_i - 1;
         prev_i = i;
+
+#ifdef GGML_CUDA_DEBUG
         if (nodes_fused > 0) {
             GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
         }
@@ -3459,16 +3460,23 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             continue;
         }
 
+        //TODO: fix this
+        static const bool graph_opt = (getenv("GGML_CUDA_GRAPH_OPT") != nullptr) && atoi(getenv("GGML_CUDA_GRAPH_OPT")) == 1;
+
         if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
-            ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
-            i += 2;
-            continue;
+            if (strncmp(cgraph->nodes[i+2]->name, "attn_norm", strlen("attn_norm")) != 0 || !graph_opt) {
+                ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+                i += 2;
+                continue;
+            }
         }
 
         if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
-            ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
-            i++;
-            continue;
+            if (strncmp(cgraph->nodes[i+1]->name, "attn_norm", strlen("attn_norm")) != 0 || !graph_opt) {
+                ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                i++;
+                continue;
+            }
         }
 
         if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
@@ -3506,7 +3514,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         if (stream_ctx.find(adjusted_node) != stream_ctx.end()) {
             concurrent_event = &stream_ctx[adjusted_node];
 
-            GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
+            GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, adjusted_node->name);
 
             cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
             GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
@@ -3520,7 +3528,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 is_concurrent_event_active = true;
            }
        }
-        prev_i = i;
    }
 }
 
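Below is a minimal, standalone sketch of the gating pattern this commit introduces: a static flag cached from the GGML_CUDA_GRAPH_OPT environment variable combined with an "attn_norm" name-prefix check, which together decide whether the RMS-norm fusion runs. The fake_node struct, the should_fuse_rms_norm helper, and the example node names are hypothetical stand-ins for illustration only; the real code checks cgraph->nodes[i+1]/[i+2] inside evaluate_and_capture_cuda_graph.

// Standalone illustration (hypothetical helper, not part of ggml):
// decide whether the RMS-norm fusion should run for a given node name.
#include <cstdio>
#include <cstdlib>
#include <cstring>

struct fake_node {            // stand-in for ggml_tensor in this sketch
    const char * name;
};

static bool should_fuse_rms_norm(const fake_node * node) {
    // Cached once per process, mirroring the diff: the optimization path
    // is only taken when GGML_CUDA_GRAPH_OPT is set to 1.
    static const bool graph_opt =
        std::getenv("GGML_CUDA_GRAPH_OPT") != nullptr &&
        std::atoi(std::getenv("GGML_CUDA_GRAPH_OPT")) == 1;

    // Skip the fusion only for nodes whose name starts with "attn_norm"
    // while the graph optimization is enabled; fuse in every other case.
    const bool is_attn_norm =
        std::strncmp(node->name, "attn_norm", std::strlen("attn_norm")) == 0;

    return !(is_attn_norm && graph_opt);
}

int main() {
    const fake_node attn = { "attn_norm-12" };
    const fake_node ffn  = { "ffn_norm-12" };
    std::printf("attn_norm fused: %d\n", should_fuse_rms_norm(&attn));
    std::printf("ffn_norm  fused: %d\n", should_fuse_rms_norm(&ffn));
}

Run with GGML_CUDA_GRAPH_OPT=1 in the environment and the attn_norm case reports 0 (fusion skipped); with the variable unset or set to any other value, both cases report 1. This is the same condition the diff expresses as strncmp(...) != 0 || !graph_opt.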
