Skip to content

Commit 3afbd9f

Browse files
committed
Fix cuda graph update logic.
1 parent 4bbe5b1 commit 3afbd9f

File tree

3 files changed

+57
-36
lines changed

3 files changed

+57
-36
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,7 @@ struct ggml_cuda_graph {
940940
size_t num_nodes = 0;
941941
std::vector<cudaGraphNode_t> nodes;
942942
std::vector<cudaKernelNodeParams> params;
943+
int number_consecutive_updates = 0;
943944
std::vector<ggml_graph_node_properties> ggml_graph_properties;
944945
#endif
945946
};
@@ -954,7 +955,9 @@ struct ggml_backend_cuda_context {
954955

955956
#ifdef USE_CUDA_GRAPH
956957
bool cuda_graph_initialized = false;
957-
bool disable_due_to_gpu_arch = false;
958+
bool disable_graph_due_to_env = false;
959+
bool disable_graph_due_to_gpu_arch = false;
960+
bool disable_graph_due_to_too_many_updates = false;
958961
#endif
959962

960963
explicit ggml_backend_cuda_context(int device) :

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2642,8 +2642,8 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
26422642
}
26432643

26442644
#ifdef USE_CUDA_GRAPH
2645-
static bool check_node_graph_compatibility(const ggml_cgraph * cgraph,
2646-
bool use_cuda_graph) {
2645+
static bool check_node_graph_compatibility(const ggml_cgraph * cgraph) {
2646+
bool use_cuda_graph = true;
26472647

26482648
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
26492649

@@ -2753,8 +2753,14 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
27532753
return true;
27542754
}
27552755

2756-
static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
2756+
static void update_cuda_graph_properties(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
2757+
cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
2758+
for (int i = 0; i < cgraph->n_nodes; i++) {
2759+
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]);
2760+
}
2761+
}
27572762

2763+
static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
27582764
bool cuda_graph_update_required = false;
27592765

27602766
if (cuda_graph->instance == nullptr) {
@@ -2768,7 +2774,6 @@ static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const gg
27682774
}
27692775

27702776
// Loop over nodes in GGML graph to determine if CUDA graph update is required
2771-
// and store properties to allow this comparison for the next token
27722777
for (int i = 0; i < cgraph->n_nodes; i++) {
27732778
bool has_matching_properties = true;
27742779
if (!cuda_graph_update_required) {
@@ -2777,7 +2782,6 @@ static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const gg
27772782
if (!has_matching_properties) {
27782783
cuda_graph_update_required = true;
27792784
}
2780-
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]);
27812785
}
27822786

27832787
return cuda_graph_update_required;
@@ -3057,22 +3061,14 @@ static void capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cuda_g
30573061
}
30583062
}
30593063

3060-
static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph* cgraph) {
3061-
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
3062-
3063-
ggml_cuda_set_device(cuda_ctx->device);
3064-
3065-
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
3066-
3067-
ggml_cuda_graph * cuda_graph = new ggml_cuda_graph();
3068-
3069-
cuda_graph->cgraph = cgraph;
3070-
3064+
static bool should_use_cuda_graph(ggml_backend_cuda_context * cuda_ctx, const struct ggml_cgraph * cgraph) {
30713065
bool use_cuda_graph = true;
30723066

30733067
if (!cuda_ctx->cuda_graph_initialized) {
3068+
cuda_ctx->disable_graph_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
3069+
30743070
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
3075-
cuda_ctx->disable_due_to_gpu_arch = true;
3071+
cuda_ctx->disable_graph_due_to_gpu_arch = true;
30763072
#ifndef NDEBUG
30773073
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
30783074
#endif
@@ -3083,17 +3079,30 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
30833079
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
30843080
// or previous graph capture failure.
30853081
// Also disable for multi-gpu for now. TO DO investigate
3086-
if (disable_cuda_graphs_due_to_env
3087-
|| cuda_ctx->disable_due_to_gpu_arch) {
3082+
if (cuda_ctx->disable_graph_due_to_env || cuda_ctx->disable_graph_due_to_gpu_arch ||
3083+
cuda_ctx->disable_graph_due_to_too_many_updates) {
30883084
use_cuda_graph = false;
30893085
}
30903086

30913087
if (use_cuda_graph) {
3092-
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
3088+
use_cuda_graph = check_node_graph_compatibility(cgraph);
30933089
}
30943090

3095-
if (use_cuda_graph) {
3091+
return use_cuda_graph;
3092+
}
3093+
3094+
static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
3095+
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
3096+
3097+
ggml_cuda_set_device(cuda_ctx->device);
3098+
3099+
ggml_cuda_graph * cuda_graph = new ggml_cuda_graph();
3100+
3101+
cuda_graph->cgraph = cgraph;
3102+
3103+
if (should_use_cuda_graph(cuda_ctx, cgraph)) {
30963104
capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
3105+
update_cuda_graph_properties(cuda_graph, cgraph);
30973106
}
30983107

30993108
return cuda_graph;
@@ -3105,7 +3114,7 @@ static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backe
31053114
GGML_UNUSED(backend);
31063115
}
31073116

3108-
static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const ggml_cgraph* cgraph) {
3117+
static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const ggml_cgraph * cgraph) {
31093118
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
31103119

31113120
ggml_cuda_set_device(cuda_ctx->device);
@@ -3114,15 +3123,28 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac
31143123

31153124
cuda_graph->cgraph = cgraph;
31163125

3117-
if (!cuda_graph->graph) {
3118-
return;
3119-
}
3120-
3121-
bool use_cuda_graph = true;
3126+
bool use_cuda_graph = should_use_cuda_graph(cuda_ctx, cgraph);
3127+
bool cuda_graph_update_required = false;
31223128

3123-
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
3129+
// check if we are doing a graph update
3130+
if (cuda_graph->instance == nullptr && use_cuda_graph // no graph -> graph
3131+
|| cuda_graph->instance != nullptr && !use_cuda_graph // graph -> no graph
3132+
|| use_cuda_graph && is_cuda_graph_update_required(cuda_graph, cgraph)) { // graph property mismatch
3133+
cuda_graph->number_consecutive_updates++;
3134+
if (cuda_graph->number_consecutive_updates >= 4) {
3135+
cuda_ctx->disable_graph_due_to_too_many_updates = true;
3136+
use_cuda_graph = false;
3137+
} else {
3138+
cuda_graph_update_required = true;
3139+
}
3140+
} else {
3141+
cuda_graph->number_consecutive_updates = 0;
3142+
}
31243143

3125-
if (!use_cuda_graph) {
3144+
if (use_cuda_graph && cuda_graph_update_required) {
3145+
capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
3146+
update_cuda_graph_properties(cuda_graph, cgraph);
3147+
} else if (!use_cuda_graph) {
31263148
if (cuda_graph->instance != nullptr) {
31273149
CUDA_CHECK(cudaGraphExecDestroy(cuda_graph->instance));
31283150
}
@@ -3132,10 +3154,6 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac
31323154
cuda_graph->instance = nullptr;
31333155
cuda_graph->graph = nullptr;
31343156
}
3135-
3136-
if (is_cuda_graph_update_required(cuda_graph, cgraph)) {
3137-
capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
3138-
}
31393157
}
31403158

31413159
static enum ggml_status ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {

ggml/src/ggml-cuda/mean.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_g
3434
// CUDA_GRAPHS_DISABLED
3535
((ncols > 65536) &&
3636
((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
37-
ctx.disable_due_to_gpu_arch)) ||
37+
ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates)) ||
3838
// CUDA_GRAPHS ENABLED
3939
((ncols > 32768) &&
4040
!((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
41-
ctx.disable_due_to_gpu_arch))) {
41+
ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates))) {
4242
#else
4343
(ncols > 65536)) {
4444
#endif // USE_CUDA_GRAPH

0 commit comments

Comments (0)