@@ -2641,9 +2641,15 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
     GGML_UNUSED(backend);
 }
 
+// groups cgraph->nodes offsets per cuda_graph
+struct cgraph_offset {
+    int begin;
+    int end;
+};
+
 #ifdef USE_CUDA_GRAPH
 static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
-    bool use_cuda_graph) {
+    bool use_cuda_graph, cgraph_offset & offset) {
 
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
 
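Note: `cgraph_offset` describes a half-open range `[begin, end)` over `cgraph->nodes`, so `end - begin` is the node count and consecutive sub-graphs chain via `next.begin = prev.end`. A minimal sketch of how such a range is consumed, in the style of the patched loops (`process_node` is a hypothetical stand-in, not a ggml function):

```cpp
// Sketch only: iterate one sub-graph's node range the way the patched
// loops do. [begin, end) is half-open, so adjacent ranges tile the
// node list without overlap.
static void walk_subrange(ggml_cgraph * cgraph, const cgraph_offset & offset) {
    for (int i = offset.begin; i < offset.end; i++) {
        ggml_tensor * node = cgraph->nodes[i];
        process_node(node); // hypothetical per-node work
    }
}
```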
@@ -2655,7 +2661,7 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
     const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
     const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
 
-    for (int i = 0; i < cgraph->n_nodes; i++) {
+    for (int i = offset.begin; i < offset.end; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
         if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
@@ -2753,45 +2759,45 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
     return true;
 }
 
-static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
+static bool is_cuda_graph_update_required(std::unique_ptr<ggml_cuda_graph> & cuda_graph, ggml_cgraph * cgraph, cgraph_offset & offset) {
 
     bool cuda_graph_update_required = false;
 
-    if (cuda_ctx->cuda_graph->instance == nullptr) {
+    if (cuda_graph->instance == nullptr) {
         cuda_graph_update_required = true;
     }
 
     // Check if the graph size has changed
-    if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+    if (cuda_graph->ggml_graph_properties.size() != (size_t)(offset.end - offset.begin)) {
         cuda_graph_update_required = true;
-        cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+        cuda_graph->ggml_graph_properties.resize((offset.end - offset.begin));
     }
 
     // Loop over nodes in GGML graph to determine if CUDA graph update is required
     // and store properties to allow this comparison for the next token
-    for (int i = 0; i < cgraph->n_nodes; i++) {
+    for (int i = offset.begin; i < offset.end; i++) {
         bool has_matching_properties = true;
         if (!cuda_graph_update_required) {
-            has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+            has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i - offset.begin]);
        }
        if (!has_matching_properties) {
            cuda_graph_update_required = true;
        }
-        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i - offset.begin]);
    }
 
     return cuda_graph_update_required;
 }
 
-static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
+static void update_cuda_graph_executable(std::unique_ptr<ggml_cuda_graph> & cuda_graph) {
 
 #if CUDART_VERSION >= 12000
     cudaGraphExecUpdateResultInfo result_info;
-    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+    cudaError_t stat = cudaGraphExecUpdate(cuda_graph->instance, cuda_graph->graph, &result_info);
 #else
     cudaGraphNode_t errorNode;
     cudaGraphExecUpdateResult result_info;
-    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+    cudaError_t stat = cudaGraphExecUpdate(cuda_graph->instance, cuda_graph->graph, &errorNode, &result_info);
 #endif // CUDART_VERSION >= 12000
 
     if (stat == cudaErrorGraphExecUpdateFailure) {
@@ -2802,9 +2808,9 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
         // The pre-existing graph exec cannot be updated due to violated constraints
         // so instead clear error and re-instantiate
         (void)cudaGetLastError();
-        CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
-        cuda_ctx->cuda_graph->instance = nullptr;
-        CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        CUDA_CHECK(cudaGraphExecDestroy(cuda_graph->instance));
+        cuda_graph->instance = nullptr;
+        CUDA_CHECK(cudaGraphInstantiate(&cuda_graph->instance, cuda_graph->graph, NULL, NULL, 0));
     } else {
         GGML_ASSERT(stat == cudaSuccess);
     }
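The hunk above follows CUDA's try-update-then-reinstantiate flow: attempt an in-place `cudaGraphExecUpdate`, and only on `cudaErrorGraphExecUpdateFailure` destroy and re-instantiate the executable. A condensed standalone sketch of that pattern (CUDA 12 signatures; `CHECK` is a hypothetical stand-in for the file's `CUDA_CHECK` macro):

```cpp
#include <cassert>
#include <cuda_runtime.h>

// Sketch: update an executable graph in place if possible; otherwise
// clear the sticky error and rebuild the executable from the captured graph.
static void refresh_graph_exec(cudaGraphExec_t & instance, cudaGraph_t graph) {
    cudaGraphExecUpdateResultInfo result_info;
    cudaError_t stat = cudaGraphExecUpdate(instance, graph, &result_info);
    if (stat == cudaErrorGraphExecUpdateFailure) {
        (void)cudaGetLastError(); // clear the error state
        CHECK(cudaGraphExecDestroy(instance));
        instance = nullptr;
        CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
    } else {
        assert(stat == cudaSuccess); // any other non-success status is fatal
    }
}
```

The in-place update is much cheaper than re-instantiation, which is why it is tried first on every token.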
@@ -2925,8 +2931,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     return false;
 }
 
-static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
-    bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
+static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, [[maybe_unused]] std::unique_ptr<ggml_cuda_graph> & cuda_graph,
+    ggml_cgraph * cgraph, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required, cgraph_offset & offset) {
     // flag used to determine whether it is an integrated_gpu
     const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
@@ -2935,7 +2941,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         // With the use of CUDA graphs, the execution will be performed by the graph launch.
         if (!use_cuda_graph || cuda_graph_update_required) {
 
-            for (int i = 0; i < cgraph->n_nodes; i++) {
+            for (int i = offset.begin; i < offset.end; i++) {
                 ggml_tensor * node = cgraph->nodes[i];
 
                 if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
@@ -3034,12 +3040,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
 #ifdef USE_CUDA_GRAPH
         if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
-            if (cuda_ctx->cuda_graph->graph != nullptr) {
-                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
-                cuda_ctx->cuda_graph->graph = nullptr;
+            if (cuda_graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(cuda_graph->graph));
+                cuda_graph->graph = nullptr;
             }
 
-            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
 
             std::lock_guard<std::mutex> lock(ggml_cuda_lock);
@@ -3052,14 +3058,14 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
     }
 
     if (use_cuda_graph) {
-        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
-            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        if (cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_graph->instance, cuda_graph->graph, NULL, NULL, 0));
         }
         if (cuda_graph_update_required) { // Update graph executable
-            update_cuda_graph_executable(cuda_ctx);
+            update_cuda_graph_executable(cuda_graph);
         }
         // Launch graph
-        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+        CUDA_CHECK(cudaGraphLaunch(cuda_graph->instance, cuda_ctx->stream()));
 #else
         graph_evaluated_or_captured = true;
 #endif // USE_CUDA_GRAPH
@@ -3071,74 +3077,107 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
     ggml_cuda_set_device(cuda_ctx->device);
 
+    // Heuristic to minimize GPU idle time. Work is split over several CUDA graphs,
+    // to overlap graph building (CPU) and graph execution (GPU).
+    // The first graphs are small to minimize the time in which the CPU prepares work and the GPU is idle.
+    // After that, graph building (CPU) is done in parallel to the execution of another previously built graph (GPU).
+    int first_graph_subset = 20;
+    int second_graph_subset = 50;
+    int remaining_graph_subset = 100;
+    int remaining_nodes = (cgraph->n_nodes - first_graph_subset) - second_graph_subset;
+    int num_cuda_graphs_required = 2 + (remaining_nodes / remaining_graph_subset);
+    cuda_ctx->cuda_graphs.resize(num_cuda_graphs_required);
+    cgraph_offset offset {0, 0};
+
+    for (size_t i = 0; i < cuda_ctx->cuda_graphs.size(); i++) {
+        auto & cuda_graph = cuda_ctx->cuda_graphs[i];
+
+        offset.begin = offset.end;
+        if (i == 0) offset.end += first_graph_subset;
+        if (i == 1) offset.end += second_graph_subset;
+        if (i >= 2) offset.end += remaining_graph_subset;
+
+        // last graph does the rest
+        if ((i + 1) == cuda_ctx->cuda_graphs.size()) offset.end = cgraph->n_nodes;
+
+        // special case for graphs smaller than the ramp-up heuristic
+        if (cgraph->n_nodes <= first_graph_subset + second_graph_subset) {
+            offset.end = cgraph->n_nodes;
+            if (i > 0) break;
+        }
+
+
+
 #ifdef USE_CUDA_GRAPH
-    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+        static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
-    // Objects required for CUDA Graph
-    if (cuda_ctx->cuda_graph == nullptr) {
-        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
-    }
+        // Objects required for CUDA Graph
+        if (cuda_graph == nullptr) {
+            cuda_graph = std::make_unique<ggml_cuda_graph>();
+        }
 
-    bool use_cuda_graph = true;
-    bool cuda_graph_update_required = false;
+        bool use_cuda_graph = true;
+        bool cuda_graph_update_required = false;
 
-    if (cuda_ctx->cuda_graph->graph == nullptr) {
-        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
-            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+        if (cuda_graph->graph == nullptr) {
+            if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
+                cuda_graph->disable_due_to_gpu_arch = true;
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
 #endif
+            }
         }
-    }
 
-    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
-    // or previous graph capture failure.
-    // Also disable for multi-gpu for now. TO DO investigate
-    if (disable_cuda_graphs_due_to_env
-        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
-        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
-        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
-        use_cuda_graph = false;
-    }
+        // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
+        // or previous graph capture failure.
+        // Also disable for multi-gpu for now. TO DO investigate
+        if (disable_cuda_graphs_due_to_env
+            || cuda_graph->disable_due_to_gpu_arch
+            || cuda_graph->disable_due_to_too_many_updates
+            || cuda_graph->disable_due_to_failed_graph_capture) {
+            use_cuda_graph = false;
+        }
 
-    if (use_cuda_graph) {
-        cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
+        if (use_cuda_graph) {
+            cuda_graph_update_required = is_cuda_graph_update_required(cuda_graph, cgraph, offset);
 
-        use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
+            use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph, offset);
 
-        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
-        if (use_cuda_graph && cuda_graph_update_required) {
-            cuda_ctx->cuda_graph->number_consecutive_updates++;
-        } else {
-            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
-        }
+            // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+            if (use_cuda_graph && cuda_graph_update_required) {
+                cuda_graph->number_consecutive_updates++;
+            } else {
+                cuda_graph->number_consecutive_updates = 0;
+            }
 
-        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
-            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+            if (cuda_graph->number_consecutive_updates >= 4) {
+                cuda_graph->disable_due_to_too_many_updates = true;
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
 #endif
+            }
         }
-    }
 
-    if (use_cuda_graph && cuda_graph_update_required) {
-        // Start CUDA graph capture
-        {
-            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
-            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
-        }
+        if (use_cuda_graph && cuda_graph_update_required) {
+            // Start CUDA graph capture
+            {
+                std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+                ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
+            }
 
-        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
-    }
+            CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+        }
 
 #else
-    bool use_cuda_graph = false;
-    bool cuda_graph_update_required = false;
+        bool use_cuda_graph = false;
+        bool cuda_graph_update_required = false;
 #endif // USE_CUDA_GRAPH
 
-    bool graph_evaluated_or_captured = false;
+        bool graph_evaluated_or_captured = false;
 
-    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
+        evaluate_and_capture_cuda_graph(cuda_ctx, cuda_graph, cgraph,
+            graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required, offset);
+    }
 
     return GGML_STATUS_SUCCESS;
 }
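To make the ramp-up arithmetic concrete: with the constants above (20/50/100), a hypothetical 300-node graph splits into `[0,20) [20,70) [70,170) [170,300)`. A standalone sketch that reproduces the partitioning for inspection (assumes C++ integer division truncating toward zero, as the patch does; the small-graph branch is restructured to collect one range instead of breaking out mid-loop, but the resulting ranges match):

```cpp
#include <cstdio>
#include <vector>

struct cgraph_offset { int begin; int end; };

// Sketch: reproduces the 20/50/100 ramp-up heuristic from
// ggml_backend_cuda_graph_compute for a given node count.
static std::vector<cgraph_offset> split_graph(int n_nodes) {
    const int first = 20, second = 50, rest = 100;
    const int remaining = (n_nodes - first) - second;
    const int n_graphs  = 2 + (remaining / rest);

    std::vector<cgraph_offset> parts;
    cgraph_offset offset {0, 0};
    for (int i = 0; i < n_graphs; i++) {
        offset.begin = offset.end;
        if (i == 0) offset.end += first;
        if (i == 1) offset.end += second;
        if (i >= 2) offset.end += rest;
        if (i + 1 == n_graphs) offset.end = n_nodes; // last graph does the rest
        if (n_nodes <= first + second) {             // small-graph special case:
            offset.end = n_nodes;                    // a single range covers everything
            parts.push_back(offset);
            break;
        }
        parts.push_back(offset);
    }
    return parts;
}

int main() {
    for (const auto & p : split_graph(300)) {
        printf("[%d, %d)\n", p.begin, p.end); // [0,20) [20,70) [70,170) [170,300)
    }
}
```

The first two small subsets get work onto the GPU quickly; from the third sub-graph on, CPU-side building of the next chunk overlaps with GPU execution of the previous one.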
@@ -3896,6 +3935,8 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
         /* .context = */ ctx,
     };
 
+    cublasHandle_t cublas_handle = ctx->cublas_handle(device);
+
     return cuda_backend;
 }
 
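The added `ctx->cublas_handle(device)` call reads as a warm-up: the backend's cuBLAS handle is created lazily on first use, so touching it during `ggml_backend_cuda_init` moves the one-time `cublasCreate` cost out of the first matmul. A minimal sketch of that lazy-init pattern, with simplified hypothetical member names (not the actual `ggml_backend_cuda_context` layout):

```cpp
#include <cublas_v2.h>
#include <cuda_runtime.h>

// Sketch: lazily created cuBLAS handle; calling the accessor once at
// init time "warms it up" so later hot paths skip cublasCreate.
struct lazy_cublas_ctx {
    cublasHandle_t handle = nullptr; // hypothetical single-device field

    cublasHandle_t cublas_handle(int device) {
        if (handle == nullptr) {
            cudaSetDevice(device); // bind the target device first
            cublasCreate(&handle); // one-time, relatively expensive
        }
        return handle;
    }
};
```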