Skip to content

Commit 21fae96

Browse files
committed
Guard: only use copy-destination pointer indirection when CUDA graphs are in use
1 parent c255a0f commit 21fae96

File tree

3 files changed: 11 additions (+11) and 3 deletions (-3)

3 files changed

+11
-3
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -711,6 +711,7 @@ struct ggml_cuda_graph {
711 711
bool disable_due_to_failed_graph_capture = false;
712 712
int number_consecutive_updates = 0;
713 713
std::vector<ggml_graph_node_properties> ggml_graph_properties;
714+
bool use_cpy_indirection = false;
714 715
std::vector<char *> cpy_dest_ptrs;
715 716
char ** dest_ptrs_d;
716 717
int dest_ptrs_size = 0;

ggml/src/ggml-cuda/cpy.cu

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -566,8 +566,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
566 566
char ** dest_ptrs_d = nullptr;
567 567
int graph_cpynode_index = -1;
568 568
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
569-
dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
570-
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
569+
if(ctx.cuda_graph->use_cpy_indirection) {
570+
dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
571+
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
572+
}
571 573
#endif
572 574
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
573 575
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
@@ -610,7 +612,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
610 612
ggml_type_name(src0->type), ggml_type_name(src1->type));
611 613
}
612 614
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
613-
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
615+
if(ctx.cuda_graph->use_cpy_indirection) {
616+
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
617+
}
614 618
#endif
615 619

616 620
}

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2481,6 +2481,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
2481 2481

2482 2482
if(use_cuda_graph)
2483 2483
{
2484+
cuda_ctx->cuda_graph->use_cpy_indirection = true;
2484 2485
// copy pointers to GPU so they can be accessed via indirection within CUDA graph
2485 2486
ggml_backend_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
2486 2487
}
@@ -2716,6 +2717,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
2716 2717
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
2717 2718
}
2718 2719

2720+
if (!use_cuda_graph) cuda_ctx->cuda_graph->use_cpy_indirection = false;
2721+
2719 2722
#else
2720 2723
bool use_cuda_graph = false;
2721 2724
bool cuda_graph_update_required = false;

0 commit comments

Comments
 (0)