@@ -3173,12 +3173,13 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                    GGML_ASSERT(concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());
                    const int stream_mapping = concurrent_event->stream_mapping[node];
                    cuda_ctx->curr_stream_no = stream_mapping;
+                    GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", stream_mapping, node->name);
                }
            }
-
-#ifdef GGML_CUDA_DEBUG
            const int nodes_fused = i - prev_i - 1;
            prev_i = i;
+
+#ifdef GGML_CUDA_DEBUG
            if (nodes_fused > 0) {
                GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
            }
@@ -3459,16 +3460,23 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                continue;
            }

+            // TODO: fix this
+            static const bool graph_opt = (getenv("GGML_CUDA_GRAPH_OPT") != nullptr) && atoi(getenv("GGML_CUDA_GRAPH_OPT")) == 1;
+
             if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
-                ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
-                i += 2;
-                continue;
+                if (strncmp(cgraph->nodes[i+2]->name, "attn_norm", strlen("attn_norm")) != 0 || !graph_opt) {
+                    ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+                    i += 2;
+                    continue;
+                }
             }

             if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
-                ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
-                i++;
-                continue;
+                if (strncmp(cgraph->nodes[i+1]->name, "attn_norm", strlen("attn_norm")) != 0 || !graph_opt) {
+                    ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                    i++;
+                    continue;
+                }
             }

             if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
@@ -3506,7 +3514,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
            if (stream_ctx.find(adjusted_node) != stream_ctx.end()) {
                concurrent_event = &stream_ctx[adjusted_node];

-                GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
+                GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, adjusted_node->name);

                cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
                GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
@@ -3520,7 +3528,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                    is_concurrent_event_active = true;
                }
            }
-            prev_i = i;
        }
    }

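For reference, the new GGML_CUDA_GRAPH_OPT check above follows the usual read-once pattern of caching a getenv/atoi result in a function-local static. Below is a minimal standalone sketch of that pattern, assuming nothing beyond the environment variable name shown in the diff; the helper name env_flag_enabled is illustrative and not part of the patch, which inlines the check directly in evaluate_and_capture_cuda_graph.

    // Sketch of a read-once environment-variable gate (illustrative only).
    #include <cstdio>
    #include <cstdlib>

    // Returns true only when the variable is set and equal to "1".
    static bool env_flag_enabled(const char * name) {
        const char * val = getenv(name);
        return val != nullptr && atoi(val) == 1;
    }

    int main() {
        // The static local is initialized on first use and then cached,
        // so the environment is consulted once per process, not per graph node.
        static const bool graph_opt = env_flag_enabled("GGML_CUDA_GRAPH_OPT");
        printf("GGML_CUDA_GRAPH_OPT: %s\n", graph_opt ? "enabled" : "disabled");
        return 0;
    }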