Commit 945a254

cuda: fix layer split mode preventing cuda graph compilation
1 parent 6f180b9 commit 945a254

2 files changed, +57 -7 lines changed


ggml/src/ggml-cuda/common.cuh

Lines changed: 8 additions & 0 deletions
@@ -763,6 +763,7 @@ struct ggml_backend_cuda_context {
     cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
     std::unique_ptr<ggml_cuda_graph> cuda_graph;
+    std::vector<ggml_cuda_graph *> cuda_graphs;
 
     explicit ggml_backend_cuda_context(int device) :
         device(device),
@@ -783,6 +784,13 @@ struct ggml_backend_cuda_context {
                 CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
             }
         }
+        while (!cuda_graphs.empty()) {
+            auto graph = cuda_graphs.back();
+            cuda_graphs.pop_back();
+            if (graph != nullptr) {
+                delete graph;
+            }
+        }
     }
 
     cudaStream_t stream(int device, int stream) {
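
Note on ownership: the cached graphs are held as raw pointers because entries are swapped in and out of the cuda_graph unique_ptr slot via release()/reset(), so the destructor has to drain the vector by hand. A minimal standalone sketch of the same pattern (graph_slot and ctx are hypothetical stand-ins, not names from the patch):

    #include <memory>
    #include <vector>

    struct graph_slot {};  // stand-in for ggml_cuda_graph

    struct ctx {
        std::unique_ptr<graph_slot> active;  // most recently used graph
        std::vector<graph_slot *>   cached;  // raw pointers, freed manually

        ~ctx() {
            while (!cached.empty()) {
                delete cached.back();  // deleting nullptr is a no-op, so the null check is optional
                cached.pop_back();
            }
        }
    };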

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 49 additions & 7 deletions
@@ -2581,34 +2581,49 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
     return true;
 }
 
-static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
+static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, ggml_cuda_graph * cuda_graph, bool update_properties) {
 
     bool cuda_graph_update_required = false;
 
-    if (cuda_ctx->cuda_graph->instance == nullptr) {
+    if (cuda_graph->instance == nullptr) {
         cuda_graph_update_required = true;
+        if (!update_properties) {
+            return cuda_graph_update_required;
+        }
     }
 
     // Check if the graph size has changed
-    if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+    if (cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
         cuda_graph_update_required = true;
-        cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+        if (update_properties) {
+            cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+        }
+        else {
+            return cuda_graph_update_required;
+        }
     }
 
     // Loop over nodes in GGML graph to determine if CUDA graph update is required
     // and store properties to allow this comparison for the next token
     for (int i = 0; i < cgraph->n_nodes; i++) {
         bool has_matching_properties = true;
         if (!cuda_graph_update_required) {
-            has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+            has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]);
         }
         if (!has_matching_properties) {
             cuda_graph_update_required = true;
+            if (!update_properties) {
+                return cuda_graph_update_required;
+            }
+        }
+        if (update_properties) {
+            set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]);
         }
-        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
     }
 
     return cuda_graph_update_required;
+
+    GGML_UNUSED(cuda_ctx);
 }
 
 static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
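
With the new update_properties flag the same routine serves two roles: a read-only probe that returns at the first mismatch without touching the stored node properties, and a commit pass that records the current graph's properties into the chosen slot. A hedged usage sketch, as it might appear inside ggml_backend_cuda_graph_compute (candidate and chosen are placeholder variables, not names from the patch):

    // Probe: cheap and read-only; true means this slot does not match the current graph.
    bool mismatch = is_cuda_graph_update_required(cuda_ctx, cgraph, candidate, /*update_properties=*/false);

    // Commit: overwrite the slot's stored per-node properties with the current
    // ggml graph, so the next token's probe compares against fresh data.
    bool rebuild = is_cuda_graph_update_required(cuda_ctx, cgraph, chosen, /*update_properties=*/true);
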
@@ -2714,6 +2729,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
     }
 
+    // the input node may change to a different address in layer split
+    // mode which causes the graph to be invalidated. cache some number of graphs
+    // and search them all.
+    while (cuda_ctx->cuda_graphs.size() < 4) {
+        cuda_ctx->cuda_graphs.emplace_back(new ggml_cuda_graph());
+    }
+
     bool use_cuda_graph = true;
     bool cuda_graph_update_required = false;
 
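
The stored node properties include the tensor's data address, so in layer split mode the same logical input arriving at a different device address fails the comparison even though nothing else changed. A toy illustration of why an address change alone forces a rebuild (node_props here is an illustrative subset of ggml_graph_node_properties, not the real struct):

    #include <cstdint>

    struct node_props {
        void *  node_address;  // tensor data pointer: differs per split placement
        int64_t ne0;           // shape (illustrative single dimension)
    };

    static bool props_match(const node_props & a, const node_props & b) {
        return a.node_address == b.node_address && a.ne0 == b.ne0;
    }
    // Same shape, different address => no match => the cached graph is invalidated,
    // which is exactly what a pool of per-address-pattern graphs avoids.
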

@@ -2737,7 +2759,27 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
     }
 
     if (use_cuda_graph) {
-        cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
+        // find a matching graph, testing most recent one first, then check lru
+        if (is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_ctx->cuda_graph.get(), false)) {
+            for (size_t graph_index = 0; graph_index < cuda_ctx->cuda_graphs.size(); graph_index++) {
+                auto cuda_graph = cuda_ctx->cuda_graphs[graph_index];
+
+                if (graph_index == cuda_ctx->cuda_graphs.size() - 1) {
+                    cuda_ctx->cuda_graphs.erase(cuda_ctx->cuda_graphs.begin() + graph_index);
+                    cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph, true);
+                    ggml_cuda_graph * existing = cuda_ctx->cuda_graph.release();
+                    cuda_ctx->cuda_graph.reset(cuda_graph);
+                    cuda_ctx->cuda_graphs.insert(cuda_ctx->cuda_graphs.begin(), existing);
+                    break;
+                } else if (!is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph, false)) {
+                    cuda_ctx->cuda_graphs.erase(cuda_ctx->cuda_graphs.begin() + graph_index);
+                    ggml_cuda_graph * existing = cuda_ctx->cuda_graph.release();
+                    cuda_ctx->cuda_graph.reset(cuda_graph);
+                    cuda_ctx->cuda_graphs.insert(cuda_ctx->cuda_graphs.begin(), existing);
+                    break;
+                }
+            }
+        }
 
         use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
 
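
The lookup treats cuda_ctx->cuda_graph as the most-recently-used entry and the vector as the rest of the LRU list: a hit (or, failing that, the forced reuse of the last slot) is promoted to the MRU position and the displaced graph is pushed to the front of the list. A minimal standalone sketch of that promote-to-front discipline (promote_match is a hypothetical helper, not part of the patch):

    #include <memory>
    #include <vector>

    template <typename T, typename Pred>
    void promote_match(std::unique_ptr<T> & mru, std::vector<T *> & rest, Pred matches) {
        for (size_t i = 0; i < rest.size(); i++) {
            // Reuse a matching entry; if none matches, fall back to the last
            // (least recently used) slot, which the caller then re-records.
            if (matches(rest[i]) || i == rest.size() - 1) {
                T * hit = rest[i];
                rest.erase(rest.begin() + i);
                rest.insert(rest.begin(), mru.release());  // demote the old MRU
                mru.reset(hit);                            // promote the hit
                return;
            }
        }
    }

With this discipline, steady-state alternation between up to four input layouts replays captured CUDA graphs instead of rebuilding one graph every token; only a genuinely new layout pays the recapture cost, evicting the least recently used slot.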
