@@ -2581,34 +2581,49 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
25812581 return true ;
25822582}
25832583
2584- static bool is_cuda_graph_update_required (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
2584+ static bool is_cuda_graph_update_required (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, ggml_cuda_graph * cuda_graph, bool update_properties ) {
25852585
25862586 bool cuda_graph_update_required = false ;
25872587
2588- if (cuda_ctx-> cuda_graph ->instance == nullptr ) {
2588+ if (cuda_graph->instance == nullptr ) {
25892589 cuda_graph_update_required = true ;
2590+ if (!update_properties) {
2591+ return cuda_graph_update_required;
2592+ }
25902593 }
25912594
25922595 // Check if the graph size has changed
2593- if (cuda_ctx-> cuda_graph ->ggml_graph_properties .size () != (size_t )cgraph->n_nodes ) {
2596+ if (cuda_graph->ggml_graph_properties .size () != (size_t )cgraph->n_nodes ) {
25942597 cuda_graph_update_required = true ;
2595- cuda_ctx->cuda_graph ->ggml_graph_properties .resize (cgraph->n_nodes );
2598+ if (update_properties) {
2599+ cuda_graph->ggml_graph_properties .resize (cgraph->n_nodes );
2600+ }
2601+ else {
2602+ return cuda_graph_update_required;
2603+ }
25962604 }
25972605
25982606 // Loop over nodes in GGML graph to determine if CUDA graph update is required
25992607 // and store properties to allow this comparison for the next token
26002608 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
26012609 bool has_matching_properties = true ;
26022610 if (!cuda_graph_update_required) {
2603- has_matching_properties = ggml_graph_node_has_matching_properties (cgraph->nodes [i], &cuda_ctx-> cuda_graph ->ggml_graph_properties [i]);
2611+ has_matching_properties = ggml_graph_node_has_matching_properties (cgraph->nodes [i], &cuda_graph->ggml_graph_properties [i]);
26042612 }
26052613 if (!has_matching_properties) {
26062614 cuda_graph_update_required = true ;
2615+ if (!update_properties) {
2616+ return cuda_graph_update_required;
2617+ }
2618+ }
2619+ if (update_properties) {
2620+ set_ggml_graph_node_properties (cgraph->nodes [i], &cuda_graph->ggml_graph_properties [i]);
26072621 }
2608- set_ggml_graph_node_properties (cgraph->nodes [i], &cuda_ctx->cuda_graph ->ggml_graph_properties [i]);
26092622 }
26102623
26112624 return cuda_graph_update_required;
2625+
2626+ GGML_UNUSED (cuda_ctx);
26122627}
26132628
26142629static void update_cuda_graph_executable (ggml_backend_cuda_context * cuda_ctx) {
@@ -2714,6 +2729,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
27142729 cuda_ctx->cuda_graph .reset (new ggml_cuda_graph ());
27152730 }
27162731
2732+ // the input node may change to a different address in layer split
2733+ // mode which cuases the graph to be invalidated. cache some number of graphs
2734+ // and search them all.
2735+ while (cuda_ctx->cuda_graphs .size () < 4 ) {
2736+ cuda_ctx->cuda_graphs .emplace_back (new ggml_cuda_graph ());
2737+ }
2738+
27172739 bool use_cuda_graph = true ;
27182740 bool cuda_graph_update_required = false ;
27192741
@@ -2737,7 +2759,27 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
27372759 }
27382760
27392761 if (use_cuda_graph) {
2740- cuda_graph_update_required = is_cuda_graph_update_required (cuda_ctx, cgraph);
2762+ // find a matching graph, testing most recent one first, then check lru
2763+ if (is_cuda_graph_update_required (cuda_ctx, cgraph, cuda_ctx->cuda_graph .get (), false )) {
2764+ for (size_t graph_index = 0 ; graph_index < cuda_ctx->cuda_graphs .size (); graph_index++) {
2765+ auto cuda_graph = cuda_ctx->cuda_graphs [graph_index];
2766+
2767+ if (graph_index == cuda_ctx->cuda_graphs .size () - 1 ) {
2768+ cuda_ctx->cuda_graphs .erase (cuda_ctx->cuda_graphs .begin () + graph_index);
2769+ cuda_graph_update_required = is_cuda_graph_update_required (cuda_ctx, cgraph, cuda_graph, true );
2770+ ggml_cuda_graph * existing = cuda_ctx->cuda_graph .release ();
2771+ cuda_ctx->cuda_graph .reset (cuda_graph);
2772+ cuda_ctx->cuda_graphs .insert (cuda_ctx->cuda_graphs .begin (), existing);
2773+ break ;
2774+ } else if (!is_cuda_graph_update_required (cuda_ctx, cgraph, cuda_graph, false )) {
2775+ cuda_ctx->cuda_graphs .erase (cuda_ctx->cuda_graphs .begin () + graph_index);
2776+ ggml_cuda_graph * existing = cuda_ctx->cuda_graph .release ();
2777+ cuda_ctx->cuda_graph .reset (cuda_graph);
2778+ cuda_ctx->cuda_graphs .insert (cuda_ctx->cuda_graphs .begin (), existing);
2779+ break ;
2780+ }
2781+ }
2782+ }
27412783
27422784 use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops (cuda_ctx, cgraph, use_cuda_graph);
27432785
0 commit comments