Skip to content

Commit eb3ea69

Browse files
committed
Refactor: Moves CUDA graph maintenance (updating the graph or adjusting copy parameters) into a separate function for improved readability.
1 parent 22c2429 commit eb3ea69

File tree

1 file changed

+50
-43
lines changed

1 file changed

+50
-43
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2337,6 +2337,55 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
23372337
}
23382338
#endif
23392339

2340+
2341+
#ifdef USE_CUDA_GRAPH
// Performs per-token maintenance on the captured CUDA graph.
//
// On steps where a graph update is required, the node handles and their kernel
// parameters are (re-)extracted from the freshly captured graph and cached in
// cuda_ctx->cuda_graph so they can be patched cheaply on later tokens.
// On all other steps, only the copy-kernel argument that changes every token is
// swapped into the cached parameters and written back into the graph nodes.
//
//   cuda_ctx                   - backend context owning the cuda_graph state
//   ggml_cuda_cpy_fn_ptrs      - device function pointers of the copy kernels
//                                whose second kernel argument must be refreshed
//   cuda_graph_update_required - true when the graph was re-captured this step
static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, const std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {

    if (cuda_graph_update_required) {
        // Extract nodes from graph.
        // First call with null argument gets number of nodes in graph.
        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
        // Subsequent call with non-null argument gets the nodes themselves.
        // clear() + resize() (rather than resize() alone) guarantees every slot
        // is freshly value-initialized, not just the newly appended ones.
        cuda_ctx->cuda_graph->nodes.clear();
        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
        cuda_ctx->cuda_graph->params.clear();
        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
        if (cuda_ctx->cuda_graph->num_nodes > 0) {
            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));

            // Loop over nodes, and extract kernel parameters from each node.
            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
                cudaGraphNodeType node_type;
                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
                if (node_type == cudaGraphNodeTypeKernel) {
                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
                    if (stat == cudaErrorInvalidDeviceFunction) {
                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
                        // We don't need to update blas nodes, so clear the sticky error and move on.
                        (void) cudaGetLastError();
                    } else {
                        GGML_ASSERT(stat == cudaSuccess);
                    }
                }
            }
        }
    } else {
        // One of the arguments to the copy kernel is updated for each token, hence we need to
        // replace that argument with the updated value in the CUDA graph.
        // On update steps, the live parameters will already be captured above.
        int k = 0;
        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
            // Membership test against the copy-kernel function pointers; std::find
            // stops at the first match instead of scanning the whole vector.
            if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) != ggml_cuda_cpy_fn_ptrs.end()) {
                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
                cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
            }
        }
    }
}
#endif
2387+
2388+
23402389
#ifdef USE_CUDA_GRAPH
23412390
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) {
23422391

@@ -2571,49 +2620,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
25712620
}
25722621

25732622
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
2574-
2575-
if (cuda_graph_update_required) {
2576-
// Extract nodes from graph
2577-
// First call with null argument gets number of nodes in graph
2578-
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
2579-
// Subsequent call with non-null argument gets nodes
2580-
cuda_ctx->cuda_graph->nodes.clear();
2581-
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
2582-
cuda_ctx->cuda_graph->params.clear();
2583-
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
2584-
if (cuda_ctx->cuda_graph->num_nodes > 0) {
2585-
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
2586-
2587-
// Loop over nodes, and extract kernel parameters from each node
2588-
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
2589-
cudaGraphNodeType node_type;
2590-
CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
2591-
if (node_type == cudaGraphNodeTypeKernel) {
2592-
cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
2593-
if (stat == cudaErrorInvalidDeviceFunction) {
2594-
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
2595-
// We don't need to update blas nodes, so clear error and move on.
2596-
cudaGetLastError();
2597-
} else {
2598-
GGML_ASSERT(stat == cudaSuccess);
2599-
}
2600-
}
2601-
}
2602-
}
2603-
}
2604-
2605-
// One of the arguments to the copy kernel is updated for each token, hence we need to
2606-
// replace that argument with the updated value in the CUDA graph
2607-
if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
2608-
int k = 0;
2609-
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
2610-
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
2611-
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
2612-
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
2613-
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
2614-
}
2615-
}
2616-
}
2623+
maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
26172624

26182625
// Update graph executable
26192626
update_cuda_graph_executable(cuda_ctx);

0 commit comments

Comments
 (0)