@@ -2642,8 +2642,8 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
26422642}
26432643
26442644#ifdef USE_CUDA_GRAPH
2645- static bool check_node_graph_compatibility (const ggml_cgraph * cgraph,
2646- bool use_cuda_graph) {
2645+ static bool check_node_graph_compatibility (const ggml_cgraph * cgraph) {
2646+ bool use_cuda_graph = true ;
26472647
26482648 // Loop over nodes in GGML graph to obtain info needed for CUDA graph
26492649
@@ -2753,8 +2753,14 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
27532753 return true ;
27542754}
27552755
2756- static bool is_cuda_graph_update_required (ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
2756+ static void update_cuda_graph_properties (ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
2757+ cuda_graph->ggml_graph_properties .resize (cgraph->n_nodes );
2758+ for (int i = 0 ; i < cgraph->n_nodes ; i++) {
2759+ set_ggml_graph_node_properties (cgraph->nodes [i], &cuda_graph->ggml_graph_properties [i]);
2760+ }
2761+ }
27572762
2763+ static bool is_cuda_graph_update_required (ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
27582764 bool cuda_graph_update_required = false ;
27592765
27602766 if (cuda_graph->instance == nullptr ) {
@@ -2768,7 +2774,6 @@ static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const gg
27682774 }
27692775
27702776 // Loop over nodes in GGML graph to determine if CUDA graph update is required
2771- // and store properties to allow this comparison for the next token
27722777 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
27732778 bool has_matching_properties = true ;
27742779 if (!cuda_graph_update_required) {
@@ -2777,7 +2782,6 @@ static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const gg
27772782 if (!has_matching_properties) {
27782783 cuda_graph_update_required = true ;
27792784 }
2780- set_ggml_graph_node_properties (cgraph->nodes [i], &cuda_graph->ggml_graph_properties [i]);
27812785 }
27822786
27832787 return cuda_graph_update_required;
@@ -3057,22 +3061,14 @@ static void capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cuda_g
30573061 }
30583062}
30593063
3060- static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create (ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
3061- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context ;
3062-
3063- ggml_cuda_set_device (cuda_ctx->device );
3064-
3065- static const bool disable_cuda_graphs_due_to_env = (getenv (" GGML_CUDA_DISABLE_GRAPHS" ) != nullptr );
3066-
3067- ggml_cuda_graph * cuda_graph = new ggml_cuda_graph ();
3068-
3069- cuda_graph->cgraph = cgraph;
3070-
3064+ static bool should_use_cuda_graph (ggml_backend_cuda_context * cuda_ctx, const struct ggml_cgraph * cgraph) {
30713065 bool use_cuda_graph = true ;
30723066
30733067 if (!cuda_ctx->cuda_graph_initialized ) {
3068+ cuda_ctx->disable_graph_due_to_env = (getenv (" GGML_CUDA_DISABLE_GRAPHS" ) != nullptr );
3069+
30743070 if (ggml_cuda_info ().devices [cuda_ctx->device ].cc < GGML_CUDA_CC_AMPERE) {
3075- cuda_ctx->disable_due_to_gpu_arch = true ;
3071+ cuda_ctx->disable_graph_due_to_gpu_arch = true ;
30763072#ifndef NDEBUG
30773073 GGML_LOG_DEBUG (" %s: disabling CUDA graphs due to GPU architecture\n " , __func__);
30783074#endif
@@ -3083,17 +3079,30 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
30833079 // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
30843080 // or previous graph capture failure.
30853081 // Also disable for multi-gpu for now. TO DO investigate
3086- if (disable_cuda_graphs_due_to_env
3087- || cuda_ctx->disable_due_to_gpu_arch ) {
3082+ if (cuda_ctx-> disable_graph_due_to_env || cuda_ctx-> disable_graph_due_to_gpu_arch ||
3083+ cuda_ctx->disable_graph_due_to_too_many_updates ) {
30883084 use_cuda_graph = false ;
30893085 }
30903086
30913087 if (use_cuda_graph) {
3092- use_cuda_graph = check_node_graph_compatibility (cgraph, use_cuda_graph );
3088+ use_cuda_graph = check_node_graph_compatibility (cgraph);
30933089 }
30943090
3095- if (use_cuda_graph) {
3091+ return use_cuda_graph;
3092+ }
3093+
3094+ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create (ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
3095+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context ;
3096+
3097+ ggml_cuda_set_device (cuda_ctx->device );
3098+
3099+ ggml_cuda_graph * cuda_graph = new ggml_cuda_graph ();
3100+
3101+ cuda_graph->cgraph = cgraph;
3102+
3103+ if (should_use_cuda_graph (cuda_ctx, cgraph)) {
30963104 capture_cuda_graph (cuda_ctx, cuda_graph, cgraph);
3105+ update_cuda_graph_properties (cuda_graph, cgraph);
30973106 }
30983107
30993108 return cuda_graph;
@@ -3105,7 +3114,7 @@ static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backe
31053114 GGML_UNUSED (backend);
31063115}
31073116
3108- static void ggml_backend_cuda_graph_plan_update (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const ggml_cgraph* cgraph) {
3117+ static void ggml_backend_cuda_graph_plan_update (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const ggml_cgraph * cgraph) {
31093118 ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context ;
31103119
31113120 ggml_cuda_set_device (cuda_ctx->device );
@@ -3114,15 +3123,28 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac
31143123
31153124 cuda_graph->cgraph = cgraph;
31163125
3117- if (!cuda_graph->graph ) {
3118- return ;
3119- }
3120-
3121- bool use_cuda_graph = true ;
3126+ bool use_cuda_graph = should_use_cuda_graph (cuda_ctx, cgraph);
3127+ bool cuda_graph_update_required = false ;
31223128
3123- use_cuda_graph = check_node_graph_compatibility (cgraph, use_cuda_graph);
3129+ // check if we are doing a graph update
3130+ if (cuda_graph->instance == nullptr && use_cuda_graph // no graph -> graph
3131+ || cuda_graph->instance != nullptr && !use_cuda_graph // graph -> no graph
3132+ || use_cuda_graph && is_cuda_graph_update_required (cuda_graph, cgraph)) { // graph property mismatch
3133+ cuda_graph->number_consecutive_updates ++;
3134+ if (cuda_graph->number_consecutive_updates >= 4 ) {
3135+ cuda_ctx->disable_graph_due_to_too_many_updates = true ;
3136+ use_cuda_graph = false ;
3137+ } else {
3138+ cuda_graph_update_required = true ;
3139+ }
3140+ } else {
3141+ cuda_graph->number_consecutive_updates = 0 ;
3142+ }
31243143
3125- if (!use_cuda_graph) {
3144+ if (use_cuda_graph && cuda_graph_update_required) {
3145+ capture_cuda_graph (cuda_ctx, cuda_graph, cgraph);
3146+ update_cuda_graph_properties (cuda_graph, cgraph);
3147+ } else if (!use_cuda_graph) {
31263148 if (cuda_graph->instance != nullptr ) {
31273149 CUDA_CHECK (cudaGraphExecDestroy (cuda_graph->instance ));
31283150 }
@@ -3132,10 +3154,6 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac
31323154 cuda_graph->instance = nullptr ;
31333155 cuda_graph->graph = nullptr ;
31343156 }
3135-
3136- if (is_cuda_graph_update_required (cuda_graph, cgraph)) {
3137- capture_cuda_graph (cuda_ctx, cuda_graph, cgraph);
3138- }
31393157}
31403158
31413159static enum ggml_status ggml_backend_cuda_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
0 commit comments