Skip to content
53 changes: 45 additions & 8 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <cfloat>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

#if defined(GGML_USE_HIP)
Expand Down Expand Up @@ -964,6 +965,38 @@ struct ggml_cuda_graph {
#endif
};

// Bookkeeping for running independent sections of a graph on multiple CUDA
// streams: a fork event, one join event per stream, and a mapping from
// tensor nodes to the stream index they execute on.  The recording/wait
// order is handled by the code that uses this struct.
struct ggml_cuda_concurrent_event {
    std::vector<cudaEvent_t> join_events; // one event per concurrent stream
    cudaEvent_t fork_event = nullptr;     // initialized so a default-constructed instance holds no stale handle

    int n_streams = 0; // number of concurrent streams (== join_events.size() after the int constructor)
    std::unordered_map<const ggml_tensor *, int> stream_mapping; // node -> stream index

    // node at which the concurrent section re-converges
    const ggml_tensor * join_node = nullptr;

    ggml_cuda_concurrent_event() = default;

    // Create the fork event and one join event per stream (timing disabled,
    // as they are used only for ordering).
    explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
        join_events.resize(n_streams);

        for (size_t i = 0; i < join_events.size(); ++i) {
            CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
        }

        CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
    }

    // NOTE(review): no destructor, so the events created above are never
    // released with cudaEventDestroy, and by-value copies (e.g. when stored
    // in an unordered_map) share the same raw event handles.  Presumably the
    // leak is accepted for the lifetime of the context -- confirm.
};

// Per-context state for multi-stream graph evaluation: the node order of the
// graph as originally submitted, plus concurrent-event bookkeeping keyed by
// tensor node.
struct ggml_cuda_stream_context {
    std::vector<const ggml_tensor *> original_graph;
    std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;

    // Drop all recorded state so the context can be reused for a new graph.
    void reset() {
        concurrent_events.clear();
        original_graph.clear();
    }
};

struct ggml_backend_cuda_context {
int device;
std::string name;
Expand All @@ -974,11 +1007,15 @@ struct ggml_backend_cuda_context {

std::unique_ptr<ggml_cuda_graph> cuda_graph;

int curr_stream_no = 0;

// Create a context bound to `device`; the context name is GGML_CUDA_NAME
// with the device index appended.
explicit ggml_backend_cuda_context(int device) :
    device(device), name(GGML_CUDA_NAME + std::to_string(device)) {}

ggml_cuda_stream_context concurrent_stream_context;

~ggml_backend_cuda_context();

cudaStream_t stream(int device, int stream) {
Expand All @@ -989,9 +1026,9 @@ struct ggml_backend_cuda_context {
return streams[device][stream];
}

// Stream for the current device, selected by curr_stream_no (set while
// evaluating concurrent sections of a graph).
// Fix: the pasted diff left both the old definition (hard-coded stream 0)
// and the new one in place, which is a redefinition of stream(); keep only
// the curr_stream_no-aware version.
cudaStream_t stream() {
    return stream(device, curr_stream_no);
}

ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }

cublasHandle_t cublas_handle(int device) {
if (cublas_handles[device] == nullptr) {
Expand All @@ -1007,15 +1044,15 @@ struct ggml_backend_cuda_context {
}

// pool: one memory pool per (device, stream) pair, created on first use by
// pool(int device).
// Fix: the pasted diff left the stale pre-change 1-D `pools` array and the
// old one-argument new_pool_for_device declaration alongside the new ones,
// which are conflicting redeclarations; keep only the 2-D versions.
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];

static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);

// Lazily create and return the memory pool for `device` on the currently
// selected stream (curr_stream_no).
// Fix: the pasted diff interleaved the old 1-D-array implementation with the
// new one, leaving `pools[device] = new_pool_for_device(device);` (an
// assignment to an array, which does not compile) and a duplicate
// unreachable return; keep only the (device, stream) implementation.
ggml_cuda_pool & pool(int device) {
    if (pools[device][curr_stream_no] == nullptr) {
        pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
    }
    return *pools[device][curr_stream_no];
}

ggml_cuda_pool & pool() {
Expand Down
Loading