@@ -2337,6 +2337,28 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
23372337}
23382338#endif
23392339
2340+
2341+ #ifdef USE_CUDA_GRAPH
2342+ static void update_cuda_graph_executable (ggml_backend_cuda_context * cuda_ctx) {
2343+
2344+ cudaGraphExecUpdateResultInfo result_info;
2345+ cudaError_t stat = cudaGraphExecUpdate (cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , &result_info);
2346+ if (stat == cudaErrorGraphExecUpdateFailure) {
2347+ #ifndef NDEBUG
2348+ GGML_LOG_DEBUG (" %s: CUDA graph update failed\n " , __func__);
2349+ #endif
2350+ // The pre-existing graph exec cannot be updated due to violated constraints
2351+ // so instead clear error and re-instantiate
2352+ cudaGetLastError ();
2353+ CUDA_CHECK (cudaGraphExecDestroy (cuda_ctx->cuda_graph ->instance ));
2354+ cuda_ctx->cuda_graph ->instance = nullptr ;
2355+ CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
2356+ } else {
2357+ GGML_ASSERT (stat == cudaSuccess);
2358+ }
2359+ }
2360+ #endif
2361+
23402362static enum ggml_status ggml_backend_cuda_graph_compute (ggml_backend_t backend, ggml_cgraph * cgraph) {
23412363 ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context ;
23422364
@@ -2585,21 +2607,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
25852607 }
25862608
25872609 // Update graph executable
2588- cudaGraphExecUpdateResultInfo result_info;
2589- cudaError_t stat = cudaGraphExecUpdate (cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , &result_info);
2590- if (stat == cudaErrorGraphExecUpdateFailure) {
2591- #ifndef NDEBUG
2592- GGML_LOG_DEBUG (" %s: CUDA graph update failed\n " , __func__);
2593- #endif
2594- // The pre-existing graph exec cannot be updated due to violated constraints
2595- // so instead clear error and re-instantiate
2596- cudaGetLastError ();
2597- CUDA_CHECK (cudaGraphExecDestroy (cuda_ctx->cuda_graph ->instance ));
2598- cuda_ctx->cuda_graph ->instance = nullptr ;
2599- CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
2600- } else {
2601- GGML_ASSERT (stat == cudaSuccess);
2602- }
2610+ update_cuda_graph_executable (cuda_ctx);
2611+
26032612 // Launch graph
26042613 CUDA_CHECK (cudaGraphLaunch (cuda_ctx->cuda_graph ->instance , cuda_ctx->stream ()));
26052614#else
0 commit comments