ggml/src/ggml-cuda/common.cuh (45 additions, 8 deletions)
@@ -25,6 +25,7 @@
 #include <cfloat>
 #include <cstdio>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #if defined(GGML_USE_HIP)
@@ -964,6 +965,38 @@ struct ggml_cuda_graph {
 #endif
 };
 
+struct ggml_cuda_concurrent_event {
+    std::vector<cudaEvent_t> join_events;
+    cudaEvent_t fork_event;
+
+    int n_streams = 0;
+    std::unordered_map<const ggml_tensor *, int> stream_mapping;
+
+    const ggml_tensor * join_node;
+
+    ggml_cuda_concurrent_event() = default;
+
+    explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
+        join_events.resize(n_streams);
+
+        for (size_t i = 0; i < join_events.size(); ++i) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
+        }
+
+        CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
+    }
+};
+
+struct ggml_cuda_stream_context {
+    std::vector<const ggml_tensor *> original_graph;
+    std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
+
+    void reset() {
+        original_graph.clear();
+        concurrent_events.clear();
+    }
+};
+
 struct ggml_backend_cuda_context {
     int device;
     std::string name;
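The fork/join intent behind ggml_cuda_concurrent_event is not visible in this header alone: fork_event marks the point where side streams branch off the main stream, and join_events lets the main stream wait for each branch. A minimal sketch of how these events could be wired up, assuming stream 0 is the main stream on the current device (the run_concurrent helper below is hypothetical, not part of the diff):

    // Sketch only: overlapping graph branches with one fork/join event set.
    // ggml_cuda_concurrent_event comes from the diff above; the helper and
    // the "launch branch" step are assumptions for illustration.
    void run_concurrent(ggml_backend_cuda_context & ctx, ggml_cuda_concurrent_event & ev) {
        cudaStream_t main_stream = ctx.stream(ctx.device, 0);

        // Fork: every side stream waits until the main stream reaches the fork point.
        CUDA_CHECK(cudaEventRecord(ev.fork_event, main_stream));
        for (int i = 1; i < ev.n_streams; ++i) {
            CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(ctx.device, i), ev.fork_event, 0));
        }

        // ... each stream now launches the kernels of its branch ...

        // Join: the main stream waits until every side stream has finished.
        for (int i = 1; i < ev.n_streams; ++i) {
            CUDA_CHECK(cudaEventRecord(ev.join_events[i], ctx.stream(ctx.device, i)));
            CUDA_CHECK(cudaStreamWaitEvent(main_stream, ev.join_events[i], 0));
        }
    }

Since the events are created with cudaEventDisableTiming, the synchronization carries no timing overhead.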
@@ -974,11 +1007,15 @@ struct ggml_backend_cuda_context {
 
     std::unique_ptr<ggml_cuda_graph> cuda_graph;
 
+    int curr_stream_no = 0;
+
     explicit ggml_backend_cuda_context(int device) :
         device(device),
         name(GGML_CUDA_NAME + std::to_string(device)) {
     }
 
+    ggml_cuda_stream_context concurrent_stream_context;
+
     ~ggml_backend_cuda_context();
 
     cudaStream_t stream(int device, int stream) {
@@ -989,9 +1026,9 @@ struct ggml_backend_cuda_context {
         return streams[device][stream];
     }
 
-    cudaStream_t stream() {
-        return stream(device, 0);
-    }
+    cudaStream_t stream() { return stream(device, curr_stream_no); }
+
+    ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
 
     cublasHandle_t cublas_handle(int device) {
         if (cublas_handles[device] == nullptr) {
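With curr_stream_no in place, every existing call site of ctx.stream() is implicitly redirected to the currently selected stream instead of hard-coded stream 0. A sketch of the per-node stream switch this enables, assuming the backend walks original_graph and consults stream_mapping before launching each node (the loop itself is an assumption; this hunk only adds the accessors):

    // Assumed dispatch pattern: select the stream assigned to each node so
    // that ctx.stream() and ctx.pool() resolve to per-stream resources.
    // stream_ctx and ev are a ggml_cuda_stream_context and a
    // ggml_cuda_concurrent_event as defined above; the loop is hypothetical.
    for (const ggml_tensor * node : stream_ctx.original_graph) {
        auto it = ev.stream_mapping.find(node);
        ctx.curr_stream_no = (it != ev.stream_mapping.end()) ? it->second : 0;
        // ... launch the kernel(s) for this node on ctx.stream() ...
    }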
@@ -1007,15 +1044,15 @@ struct ggml_backend_cuda_context {
     }
 
     // pool
-    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
+    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
 
-    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
+    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
 
     ggml_cuda_pool & pool(int device) {
-        if (pools[device] == nullptr) {
-            pools[device] = new_pool_for_device(device);
+        if (pools[device][curr_stream_no] == nullptr) {
+            pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
         }
-        return *pools[device];
+        return *pools[device][curr_stream_no];
     }
 
     ggml_cuda_pool & pool() {