@@ -2337,6 +2337,55 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
23372337}
23382338#endif
23392339
2340+
#ifdef USE_CUDA_GRAPH
// Keep the captured CUDA graph's kernel parameters in sync for the current token.
//
// On update steps (cuda_graph_update_required == true): the graph has just been
// re-captured, so (re-)extract the node list and, for every kernel node, its
// parameter struct. cudaGraphKernelNodeGetParams can fail with
// cudaErrorInvalidDeviceFunction on CUBLAS-inserted nodes; those nodes never
// need patching, so that specific error is cleared and ignored.
//
// On non-update steps: only one copy-kernel argument changes per token, so for
// each node whose function pointer matches a known ggml copy kernel
// (ggml_cuda_cpy_fn_ptrs), patch kernelParams[1] with the pointer previously
// stashed in updated_kernel_arg and write the params back into the graph node.
//
// cuda_ctx              - backend context owning the cuda_graph state (read/written)
// ggml_cuda_cpy_fn_ptrs - device-function pointers of the copy kernels; read-only,
//                         passed by const reference to avoid a per-token vector copy
// cuda_graph_update_required - selects extract (true) vs patch (false) mode
static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, const std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {

    if (cuda_graph_update_required) {
        // Extract nodes from graph.
        // First call with null argument gets number of nodes in graph.
        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
        // Subsequent call with non-null argument gets nodes.
        cuda_ctx->cuda_graph->nodes.clear();
        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
        cuda_ctx->cuda_graph->params.clear();
        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
        if (cuda_ctx->cuda_graph->num_nodes > 0) {
            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));

            // Loop over nodes, and extract kernel parameters from each node.
            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
                cudaGraphNodeType node_type;
                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
                if (node_type == cudaGraphNodeTypeKernel) {
                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
                    if (stat == cudaErrorInvalidDeviceFunction) {
                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
                        // We don't need to update blas nodes, so clear the sticky error and move on.
                        (void) cudaGetLastError();
                    } else {
                        GGML_ASSERT(stat == cudaSuccess);
                    }
                }
            }
        }
    } else {
        // One of the arguments to the copy kernel is updated for each token, hence we need to
        // replace that argument with the updated value in the CUDA graph.
        // On update steps, the live parameters will already be captured.
        int k = 0;
        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
            if (std::count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
                cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
            }
        }
    }
}
#endif
2387+
2388+
23402389#ifdef USE_CUDA_GRAPH
23412390static bool is_cuda_graph_update_required (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) {
23422391
@@ -2571,49 +2620,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
25712620 }
25722621
25732622 // Perform update to graph (if required for this token), and change copy parameter (required for every token)
2574-
2575- if (cuda_graph_update_required) {
2576- // Extract nodes from graph
2577- // First call with null argument gets number of nodes in graph
2578- CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , nullptr , &cuda_ctx->cuda_graph ->num_nodes ));
2579- // Subsequent call with non-null argument gets nodes
2580- cuda_ctx->cuda_graph ->nodes .clear ();
2581- cuda_ctx->cuda_graph ->nodes .resize (cuda_ctx->cuda_graph ->num_nodes );
2582- cuda_ctx->cuda_graph ->params .clear ();
2583- cuda_ctx->cuda_graph ->params .resize (cuda_ctx->cuda_graph ->num_nodes );
2584- if (cuda_ctx->cuda_graph ->num_nodes > 0 ) {
2585- CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , cuda_ctx->cuda_graph ->nodes .data (), &cuda_ctx->cuda_graph ->num_nodes ));
2586-
2587- // Loop over nodes, and extract kernel parameters from each node
2588- for (size_t i = 0 ; i < cuda_ctx->cuda_graph ->num_nodes ; i++) {
2589- cudaGraphNodeType node_type;
2590- CUDA_CHECK (cudaGraphNodeGetType (cuda_ctx->cuda_graph ->nodes [i], &node_type));
2591- if (node_type == cudaGraphNodeTypeKernel) {
2592- cudaError_t stat = cudaGraphKernelNodeGetParams (cuda_ctx->cuda_graph ->nodes [i], &cuda_ctx->cuda_graph ->params [i]); // Get params using runtime
2593- if (stat == cudaErrorInvalidDeviceFunction) {
2594- // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
2595- // We don't need to update blas nodes, so clear error and move on.
2596- cudaGetLastError ();
2597- } else {
2598- GGML_ASSERT (stat == cudaSuccess);
2599- }
2600- }
2601- }
2602- }
2603- }
2604-
2605- // One of the arguments to the copy kernel is updated for each token, hence we need to
2606- // replace that argument with the updated value in the CUDA graph
2607- if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
2608- int k = 0 ;
2609- for (size_t i = 0 ; i < cuda_ctx->cuda_graph ->num_nodes ; i++) {
2610- if (count (ggml_cuda_cpy_fn_ptrs.begin (), ggml_cuda_cpy_fn_ptrs.end (), cuda_ctx->cuda_graph ->params [i].func ) > 0 ) {
2611- char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph ->updated_kernel_arg .at (k++);
2612- cuda_ctx->cuda_graph ->params [i].kernelParams [1 ] = updated_kernel_arg_ptr;
2613- CUDA_CHECK (cudaGraphKernelNodeSetParams (cuda_ctx->cuda_graph ->nodes [i], &cuda_ctx->cuda_graph ->params [i]));
2614- }
2615- }
2616- }
2623+ maintain_cuda_graph (cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
26172624
26182625 // Update graph executable
26192626 update_cuda_graph_executable (cuda_ctx);
0 commit comments