@@ -2337,6 +2337,36 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
23372337}
23382338#endif
23392339
2340+ #ifdef USE_CUDA_GRAPH
2341+ static bool is_cuda_graph_update_required (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) {
2342+
2343+ if (cuda_ctx->cuda_graph ->instance == nullptr ) {
2344+ cuda_graph_update_required = true ;
2345+ }
2346+
2347+ // Check if the graph size has changed
2348+ if (cuda_ctx->cuda_graph ->ggml_graph_properties .size () != (size_t )cgraph->n_nodes ) {
2349+ cuda_graph_update_required = true ;
2350+ cuda_ctx->cuda_graph ->ggml_graph_properties .resize (cgraph->n_nodes );
2351+ }
2352+
2353+ // Loop over nodes in GGML graph to determine if CUDA graph update is required
2354+ // and store properties to allow this comparison for the next token
2355+ for (int i = 0 ; i < cgraph->n_nodes ; i++) {
2356+ bool has_matching_properties = true ;
2357+ if (!cuda_graph_update_required) {
2358+ has_matching_properties = ggml_graph_node_has_matching_properties (cgraph->nodes [i], &cuda_ctx->cuda_graph ->ggml_graph_properties [i]);
2359+ }
2360+ if (!has_matching_properties) {
2361+ cuda_graph_update_required = true ;
2362+ }
2363+ set_ggml_graph_node_properties (cgraph->nodes [i], &cuda_ctx->cuda_graph ->ggml_graph_properties [i]);
2364+ }
2365+
2366+ return cuda_graph_update_required;
2367+ }
2368+ #endif
2369+
23402370
23412371#ifdef USE_CUDA_GRAPH
23422372static void update_cuda_graph_executable (ggml_backend_cuda_context * cuda_ctx) {
@@ -2398,28 +2428,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
23982428 }
23992429
24002430 if (use_cuda_graph) {
2401- if (cuda_ctx->cuda_graph ->instance == nullptr ) {
2402- cuda_graph_update_required = true ;
2403- }
2404-
2405- // Check if the graph size has changed
2406- if (cuda_ctx->cuda_graph ->ggml_graph_properties .size () != (size_t )cgraph->n_nodes ) {
2407- cuda_graph_update_required = true ;
2408- cuda_ctx->cuda_graph ->ggml_graph_properties .resize (cgraph->n_nodes );
2409- }
2410-
2411- // Loop over nodes in GGML graph to determine if CUDA graph update is required
2412- // and store properties to allow this comparison for the next token
2413- for (int i = 0 ; i < cgraph->n_nodes ; i++) {
2414- bool has_matching_properties = true ;
2415- if (!cuda_graph_update_required) {
2416- has_matching_properties = ggml_graph_node_has_matching_properties (cgraph->nodes [i], &cuda_ctx->cuda_graph ->ggml_graph_properties [i]);
2417- }
2418- if (!has_matching_properties) {
2419- cuda_graph_update_required = true ;
2420- }
2421- set_ggml_graph_node_properties (cgraph->nodes [i], &cuda_ctx->cuda_graph ->ggml_graph_properties [i]);
2422- }
2431+ cuda_graph_update_required = is_cuda_graph_update_required (cuda_ctx, cgraph, cuda_graph_update_required);
24232432
24242433 // Loop over nodes in GGML graph to obtain info needed for CUDA graph
24252434 cuda_ctx->cuda_graph ->updated_kernel_arg .clear ();
0 commit comments