@@ -2438,11 +2438,95 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
+
+static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+    [[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph,
+    bool & cuda_graph_update_required) {
+
+    while (!graph_evaluated_or_captured) {
+        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
+        // With the use of CUDA graphs, the execution will be performed by the graph launch.
+        if (!use_cuda_graph || cuda_graph_update_required) {
+            for (int i = 0; i < cgraph->n_nodes; i++) {
+                ggml_tensor * node = cgraph->nodes[i];
+
+                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+                    continue;
+                }
+
+#ifndef NDEBUG
+                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    if (node->src[j] != nullptr) {
+                        assert(node->src[j]->buffer);
+                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
+                               ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
+                    }
+                }
+#endif
+
+                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
+                if (!ok) {
+                    GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+                }
+                GGML_ASSERT(ok);
+            }
+        }
+
+#ifdef USE_CUDA_GRAPH
+        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
+            if (cuda_ctx->cuda_graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
+                cuda_ctx->cuda_graph->graph = nullptr;
+            }
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+
+#if 0
+            if (disable_cuda_graphs_due_to_failed_capture) {
+                use_cuda_graph = false;
+                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
+#endif
+            } else {
+                graph_evaluated_or_captured = true; // CUDA graph has been captured
+            }
+#endif
+            graph_evaluated_or_captured = true; // CUDA graph has been captured
+        } else {
+            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
+        }
+    }
+
+    if (use_cuda_graph) {
+        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        }
+
+        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
+        maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
+
+        // Update graph executable
+        update_cuda_graph_executable(cuda_ctx);
+
+        // Launch graph
+        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+#else
+        graph_evaluated_or_captured = true;
+#endif // USE_CUDA_GRAPH
+    }
+}
+
+
 static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
 
     ggml_cuda_set_device(cuda_ctx->device);
 
+    // vector of pointers to CUDA cpy kernels, which are required to identify
+    // kernel parameters which need updated in the graph for each token
+    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
+
 #ifdef USE_CUDA_GRAPH
     static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
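The helper added above follows the standard CUDA graph workflow: work submitted while the stream is in capture mode is recorded rather than executed (the matching cudaStreamBeginCapture call lives elsewhere in ggml-cuda.cu and is not part of this hunk), the captured graph is instantiated once into an executable graph, and that executable is then replayed with a single launch per token. As a minimal standalone sketch of that pattern, with an illustrative kernel and variable names that are not taken from this patch:

```cuda
// Illustrative sketch of the capture -> instantiate -> launch pattern that
// evaluate_and_capture_cuda_graph drives. Not code from this patch.
// Build with: nvcc -o graph_sketch graph_sketch.cu
#include <cstdio>
#include <cuda_runtime.h>

#define CUDA_TRY(call)                                                      \
    do {                                                                    \
        cudaError_t err_ = (call);                                          \
        if (err_ != cudaSuccess) {                                          \
            fprintf(stderr, "CUDA error %s at %s:%d\n",                     \
                    cudaGetErrorString(err_), __FILE__, __LINE__);          \
            return 1;                                                       \
        }                                                                   \
    } while (0)

__global__ void scale(float * x, float a, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        x[i] *= a;
    }
}

int main() {
    const int n = 1 << 16;
    float * d_x = nullptr;
    CUDA_TRY(cudaMalloc(&d_x, n*sizeof(float)));
    CUDA_TRY(cudaMemset(d_x, 0, n*sizeof(float)));

    cudaStream_t stream;
    CUDA_TRY(cudaStreamCreate(&stream));

    // 1. Capture: launches on the stream are recorded into a graph instead of
    //    being executed (stand-in for the loop over cgraph->n_nodes).
    cudaGraph_t graph = nullptr;
    CUDA_TRY(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed));
    for (int i = 0; i < 4; i++) {
        scale<<<(n + 255)/256, 256, 0, stream>>>(d_x, 1.01f, n);
    }
    CUDA_TRY(cudaStreamEndCapture(stream, &graph));

    // 2. Instantiate once: build an executable graph from the captured graph.
    cudaGraphExec_t instance = nullptr;
    CUDA_TRY(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));

    // 3. Launch: replay the whole recorded sequence with one API call per step,
    //    which is where the per-token launch-overhead saving comes from.
    for (int step = 0; step < 8; step++) {
        CUDA_TRY(cudaGraphLaunch(instance, stream));
    }
    CUDA_TRY(cudaStreamSynchronize(stream));

    CUDA_TRY(cudaGraphExecDestroy(instance));
    CUDA_TRY(cudaGraphDestroy(graph));
    CUDA_TRY(cudaStreamDestroy(stream));
    CUDA_TRY(cudaFree(d_x));
    printf("done\n");
    return 0;
}
```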
@@ -2453,9 +2537,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
     bool use_cuda_graph = true;
     bool cuda_graph_update_required = false;
-    // vector of pointers to CUDA cpy kernels, which are required to identify
-    // kernel parameters which need updated in the graph for each token
-    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
 
     if (cuda_ctx->cuda_graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
@@ -2559,79 +2640,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
     bool graph_evaluated_or_captured = false;
 
-    while (!graph_evaluated_or_captured) {
-        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
-        // With the use of CUDA graphs, the execution will be performed by the graph launch.
-        if (!use_cuda_graph || cuda_graph_update_required) {
-            for (int i = 0; i < cgraph->n_nodes; i++) {
-                ggml_tensor * node = cgraph->nodes[i];
-
-                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-                    continue;
-                }
-
-#ifndef NDEBUG
-                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
-                for (int j = 0; j < GGML_MAX_SRC; j++) {
-                    if (node->src[j] != nullptr) {
-                        assert(node->src[j]->buffer);
-                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
-                               ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
-                    }
-                }
-#endif
-
-                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
-                if (!ok) {
-                    GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
-                }
-                GGML_ASSERT(ok);
-            }
-        }
-
-#ifdef USE_CUDA_GRAPH
-        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
-            if (cuda_ctx->cuda_graph->graph != nullptr) {
-                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
-                cuda_ctx->cuda_graph->graph = nullptr;
-            }
-            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
-
-#if 0
-            if (disable_cuda_graphs_due_to_failed_capture) {
-                use_cuda_graph = false;
-                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
-#endif
-            } else {
-                graph_evaluated_or_captured = true; // CUDA graph has been captured
-            }
-#endif
-            graph_evaluated_or_captured = true; // CUDA graph has been captured
-        } else {
-            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
-        }
-    }
-
-    if (use_cuda_graph) {
-        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
-            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
-        }
-
-        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
-        maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
-
-        // Update graph executable
-        update_cuda_graph_executable(cuda_ctx);
-
-        // Launch graph
-        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
-#else
-        graph_evaluated_or_captured = true;
-#endif // USE_CUDA_GRAPH
-    }
-
+    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
     return GGML_STATUS_SUCCESS;
 }
 
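The helpers maintain_cuda_graph and update_cuda_graph_executable are defined earlier in the file and are not shown in this diff. As a rough, hedged illustration of what an "update graph executable" step can look like with the CUDA 12 runtime API, the sketch below applies cudaGraphExecUpdate to an existing executable and falls back to a full re-instantiation when the in-place update fails; the helper name and error handling here are assumptions, not this patch's code.

```cuda
// Hypothetical helper (assumption, not from this patch): refresh an existing
// cudaGraphExec_t from a newly captured cudaGraph_t. Uses the CUDA 12
// cudaGraphExecUpdateResultInfo signature.
#include <cuda_runtime.h>

static void refresh_graph_exec(cudaGraphExec_t & instance, cudaGraph_t graph) {
    cudaGraphExecUpdateResultInfo result_info;
    cudaError_t stat = cudaGraphExecUpdate(instance, graph, &result_info);
    if (stat == cudaErrorGraphExecUpdateFailure) {
        // The graph topology changed, so an in-place update is not possible:
        // clear the error state, drop the stale executable and instantiate again.
        (void) cudaGetLastError();
        cudaGraphExecDestroy(instance);
        instance = nullptr;
        cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
    }
}
```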