Skip to content

Commit b8b08e3

Browse files
committed
ggml-cuda: rename + remove unused var
1 parent d385f76 commit b8b08e3

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -966,9 +966,8 @@ struct ggml_cuda_graph {
966966
};
967967

968968
struct ggml_cuda_concurrent_event {
969-
std::vector<cudaEvent_t> per_stream_events;
969+
std::vector<cudaEvent_t> join_events;
970970
cudaEvent_t fork_event;
971-
cudaEvent_t join_event;
972971

973972
int n_streams = 0;
974973
std::unordered_map<const ggml_tensor *, int> stream_mapping;
@@ -978,14 +977,13 @@ struct ggml_cuda_concurrent_event {
978977
ggml_cuda_concurrent_event() = default;
979978

980979
explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
981-
per_stream_events.resize(n_streams);
980+
join_events.resize(n_streams);
982981

983-
for (size_t i = 0; i < per_stream_events.size(); ++i) {
984-
CUDA_CHECK(cudaEventCreateWithFlags(&per_stream_events[i], cudaEventDisableTiming));
982+
for (size_t i = 0; i < join_events.size(); ++i) {
983+
CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
985984
}
986985

987986
CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
988-
CUDA_CHECK(cudaEventCreateWithFlags(&join_event, cudaEventDisableTiming));
989987
}
990988
};
991989

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3230,9 +3230,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
32303230
cuda_ctx->curr_stream_no = 0;
32313231
for (int i = 1; i <= concurrent_event->n_streams; ++i) {
32323232
// Wait on join events of forked streams in the main stream
3233-
CUDA_CHECK(cudaEventRecord(concurrent_event->per_stream_events[i - 1],
3233+
CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1],
32343234
cuda_ctx->stream(cuda_ctx->device, i)));
3235-
CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->per_stream_events[i - 1]));
3235+
CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1]));
32363236
}
32373237

32383238
is_concurrent_event_active = false;

0 commit comments

Comments
 (0)