Skip to content

Commit 1129188

Browse files
committed
ggml-cuda: use lambda instead of duplicating code
1 parent 6562b77 commit 1129188

File tree

1 file changed

+20
-37
lines changed

1 file changed

+20
-37
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 20 additions & 37 deletions
Original file line number · Diff line number · Diff line change
@@ -3192,6 +3192,23 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
31923192
bool is_concurrent_event_active = false;
31933193
ggml_cuda_concurrent_event * concurrent_event = nullptr;
31943194

3195+
const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
3196+
if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
3197+
concurrent_event = &stream_ctx.concurrent_events[node];
3198+
3199+
GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
3200+
3201+
cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
3202+
GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
3203+
CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
3204+
3205+
for (int i = 1; i <= concurrent_event->n_streams; ++i) {
3206+
cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
3207+
CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
3208+
}
3209+
}
3210+
};
3211+
31953212
while (!graph_evaluated_or_captured) {
31963213
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
31973214
// With the use of CUDA graphs, the execution will be performed by the graph launch.
@@ -3212,6 +3229,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
32123229
if (node == concurrent_event->join_node) {
32133230
cuda_ctx->curr_stream_no = 0;
32143231
for (int i = 1; i <= concurrent_event->n_streams; ++i) {
3232+
// Wait on join events of forked streams in the main stream
32153233
CUDA_CHECK(cudaEventRecord(concurrent_event->per_stream_events[i - 1],
32163234
cuda_ctx->stream(cuda_ctx->device, i)));
32173235
CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->per_stream_events[i - 1]));
@@ -3230,24 +3248,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
32303248

32313249
//the previous node was fused
32323250
const ggml_tensor * prev_node = cgraph->nodes[i - 1];
3233-
if (stream_ctx.concurrent_events.find(prev_node) != stream_ctx.concurrent_events.end()) {
3234-
concurrent_event = &stream_ctx.concurrent_events[prev_node];
3235-
3236-
GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, prev_node->name);
3237-
3238-
cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
3239-
GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
3240-
CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
3241-
3242-
for (int i = 1; i <= concurrent_event->n_streams; ++i) {
3243-
cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
3244-
CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
3245-
}
3246-
3247-
is_concurrent_event_active = true;
3248-
cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
3249-
GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
3250-
}
3251+
try_launch_concurrent_event(prev_node);
32513252
}
32523253
prev_i = i;
32533254

@@ -3565,25 +3566,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
35653566
GGML_ASSERT(ok);
35663567

35673568
if (!is_concurrent_event_active) {
3568-
//const ggml_tensor * adjusted_node = node;
3569-
// the forking node may have been fused, e.g (RMS_NORM_MUL + MUL + ADD),
3570-
// we can safely use the previous node to check if it can be forked
3571-
if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
3572-
concurrent_event = &stream_ctx.concurrent_events[node];
3573-
3574-
GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
3575-
3576-
cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
3577-
GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
3578-
CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
3579-
3580-
for (int i = 1; i <= concurrent_event->n_streams; ++i) {
3581-
cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
3582-
CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
3583-
}
3584-
3585-
is_concurrent_event_active = true;
3586-
}
3569+
try_launch_concurrent_event(node);
35873570
}
35883571
}
35893572
}

0 commit comments

Comments (0)