Skip to content

Commit 22c2429

Browse files
committed
Refactor: Moves cuda graph update check to separate function.
1 parent ba05331 commit 22c2429

File tree

1 file changed

+31
-22
lines changed

1 file changed

+31
-22
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2337,6 +2337,36 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
23372337
}
23382338
#endif
23392339

2340+
#ifdef USE_CUDA_GRAPH
2341+
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) {
2342+
2343+
if (cuda_ctx->cuda_graph->instance == nullptr) {
2344+
cuda_graph_update_required = true;
2345+
}
2346+
2347+
// Check if the graph size has changed
2348+
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
2349+
cuda_graph_update_required = true;
2350+
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
2351+
}
2352+
2353+
// Loop over nodes in GGML graph to determine if CUDA graph update is required
2354+
// and store properties to allow this comparison for the next token
2355+
for (int i = 0; i < cgraph->n_nodes; i++) {
2356+
bool has_matching_properties = true;
2357+
if (!cuda_graph_update_required) {
2358+
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2359+
}
2360+
if (!has_matching_properties) {
2361+
cuda_graph_update_required = true;
2362+
}
2363+
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2364+
}
2365+
2366+
return cuda_graph_update_required;
2367+
}
2368+
#endif
2369+
23402370

23412371
#ifdef USE_CUDA_GRAPH
23422372
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
@@ -2398,28 +2428,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
23982428
}
23992429

24002430
if (use_cuda_graph) {
2401-
if (cuda_ctx->cuda_graph->instance == nullptr) {
2402-
cuda_graph_update_required = true;
2403-
}
2404-
2405-
// Check if the graph size has changed
2406-
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
2407-
cuda_graph_update_required = true;
2408-
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
2409-
}
2410-
2411-
// Loop over nodes in GGML graph to determine if CUDA graph update is required
2412-
// and store properties to allow this comparison for the next token
2413-
for (int i = 0; i < cgraph->n_nodes; i++) {
2414-
bool has_matching_properties = true;
2415-
if (!cuda_graph_update_required) {
2416-
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2417-
}
2418-
if (!has_matching_properties) {
2419-
cuda_graph_update_required = true;
2420-
}
2421-
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2422-
}
2431+
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph_update_required);
24232432

24242433
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
24252434
cuda_ctx->cuda_graph->updated_kernel_arg.clear();

0 commit comments

Comments
 (0)