Commit 788a6a2

ggml-cuda: fix fusion inside stream

1 parent 95da850 commit 788a6a2

2 files changed: +54 -21 lines

ggml/src/ggml-cuda/common.cuh
Lines changed: 9 additions & 1 deletion

@@ -983,7 +983,15 @@ struct ggml_cuda_concurrent_event {
     }
 };
 
-using ggml_cuda_stream_context = std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event>;
+struct ggml_cuda_stream_context {
+    std::vector<const ggml_tensor *> original_graph;
+    std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
+
+    void reset() {
+        original_graph.clear();
+        concurrent_events.clear();
+    }
+};
 
 struct ggml_backend_cuda_context {
     int device;
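
Note: below is a minimal, self-contained sketch of how the new struct is meant to be used: graph_optimize snapshots the untouched node order into original_graph and registers concurrent events, and evaluate falls back to that snapshot whenever concurrent events exist. The types and functions here (fake_tensor, fake_event, optimize, evaluate) are simplified stand-ins for illustration, not the ggml/CUDA backend API.

#include <cstdio>
#include <unordered_map>
#include <utility>
#include <vector>

struct fake_tensor { const char * name; };   // stand-in for ggml_tensor
struct fake_event  { int n_streams;     };   // stand-in for ggml_cuda_concurrent_event

struct stream_context {
    std::vector<const fake_tensor *> original_graph;                       // snapshot taken in graph_optimize
    std::unordered_map<const fake_tensor *, fake_event> concurrent_events; // fork points keyed by root node

    void reset() {
        original_graph.clear();
        concurrent_events.clear();
    }
};

// "graph_optimize": remember the untouched order before reordering for concurrency
static void optimize(stream_context & ctx, std::vector<const fake_tensor *> & nodes) {
    ctx.reset();
    ctx.original_graph = nodes;                 // save the original graph
    std::swap(nodes[1], nodes[2]);              // pretend interleaving / reordering
    ctx.concurrent_events.emplace(nodes[0], fake_event{2});
}

// "evaluate": when concurrent events were planned, walk the original node order
static void evaluate(const stream_context & ctx, const std::vector<const fake_tensor *> & nodes) {
    const auto & order = ctx.concurrent_events.empty() ? nodes : ctx.original_graph;
    for (const fake_tensor * t : order) {
        std::printf("%s\n", t->name);
    }
}

int main() {
    fake_tensor a{"attn_norm"}, q{"Qcur"}, k{"Kcur"};
    std::vector<const fake_tensor *> nodes = {&a, &q, &k};
    stream_context ctx;
    optimize(ctx, nodes);
    evaluate(ctx, nodes);   // prints attn_norm, Qcur, Kcur (original order)
    return 0;
}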

ggml/src/ggml-cuda/ggml-cuda.cu
Lines changed: 45 additions & 20 deletions

@@ -3153,6 +3153,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
     [[maybe_unused]] int prev_i = 0;
 
     ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
+    if (stream_ctx.concurrent_events.size() > 0) {
+        cgraph->nodes = const_cast<ggml_tensor **>(stream_ctx.original_graph.data());
+    }
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -3176,6 +3179,26 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 cuda_ctx->curr_stream_no = stream_mapping;
                 GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", stream_mapping, node->name);
             }
+        } else if (i - prev_i > 1) {
+
+            //the previous node was fused
+            const ggml_tensor * prev_node = cgraph->nodes[i - 1];
+            if (stream_ctx.concurrent_events.find(prev_node) != stream_ctx.concurrent_events.end()) {
+                concurrent_event = &stream_ctx.concurrent_events[prev_node];
+
+                GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
+
+                cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
+                GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
+                CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
+
+                for (int i = 1; i <= concurrent_event->n_streams; ++i) {
+                    cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
+                    CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
+                }
+
+                is_concurrent_event_active = true;
+            }
         }
         prev_i = i;
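
The fork added in this branch follows the standard CUDA event pattern: record an event on the main stream, then have each secondary stream wait on that event before running its kernels. A standalone illustration of that pattern follows; the kernel, stream count, and CHECK macro are placeholders for illustration, not the backend's own helpers.

// fork_join.cu -- minimal event-based fork across streams; build with: nvcc fork_join.cu -o fork_join
#include <cstdio>
#include <cuda_runtime.h>

__global__ void work(float * data, int n, float add) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) data[i] += add;
}

#define CHECK(x) do { cudaError_t e = (x); if (e != cudaSuccess) { \
    std::printf("CUDA error: %s (line %d)\n", cudaGetErrorString(e), __LINE__); return 1; } } while (0)

int main() {
    const int n = 1 << 20, n_streams = 3;
    float * data = nullptr;
    CHECK(cudaMalloc(&data, n * sizeof(float)));
    CHECK(cudaMemset(data, 0, n * sizeof(float)));

    cudaStream_t main_stream, side[n_streams];
    CHECK(cudaStreamCreate(&main_stream));
    for (int s = 0; s < n_streams; ++s) CHECK(cudaStreamCreate(&side[s]));

    cudaEvent_t fork_event;
    CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));

    // run work on the main stream, then fork: record the event and make every
    // side stream wait on it before launching its own kernel
    work<<<(n + 255) / 256, 256, 0, main_stream>>>(data, n, 1.0f);
    CHECK(cudaEventRecord(fork_event, main_stream));
    for (int s = 0; s < n_streams; ++s) {
        CHECK(cudaStreamWaitEvent(side[s], fork_event, 0));
        work<<<(n + 255) / 256, 256, 0, side[s]>>>(data, n, 1.0f);
    }

    CHECK(cudaDeviceSynchronize());
    std::printf("done\n");

    CHECK(cudaEventDestroy(fork_event));
    for (int s = 0; s < n_streams; ++s) CHECK(cudaStreamDestroy(side[s]));
    CHECK(cudaStreamDestroy(main_stream));
    CHECK(cudaFree(data));
    return 0;
}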

@@ -3446,23 +3469,16 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 continue;
             }
 
-            //TODO: fix this
-            static const bool graph_opt = (getenv("GGML_CUDA_GRAPH_OPT") != nullptr) && atoi(getenv("GGML_CUDA_GRAPH_OPT")) == 1;
-
             if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
-                if (strncmp(cgraph->nodes[i+2]->name, "attn_norm", strlen("attn_norm")) != 0 || !graph_opt) {
-                    ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
-                    i += 2;
-                    continue;
-                }
+                ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+                i += 2;
+                continue;
             }
 
             if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
-                if (strncmp(cgraph->nodes[i+1]->name, "attn_norm", strlen("attn_norm")) != 0 || !graph_opt) {
-                    ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
-                    i++;
-                    continue;
-                }
+                ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                i++;
+                continue;
             }
 
             if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
 
@@ -3494,8 +3510,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             //const ggml_tensor * adjusted_node = node;
             // the forking node may have been fused, e.g (RMS_NORM_MUL + MUL + ADD),
             // we can safely use the previous node to check if it can be forked
-            if (stream_ctx.find(node) != stream_ctx.end()) {
-                concurrent_event = &stream_ctx[node];
+            if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
+                concurrent_event = &stream_ctx.concurrent_events[node];
 
                 GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);

@@ -3666,7 +3682,7 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
     GGML_LOG_DEBUG("Optimizing CUDA graph %p %d\n", cgraph->nodes, cgraph->n_nodes);
 
     ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
-    stream_context.clear();
+    stream_context.reset();
 
     std::unordered_map<const ggml_tensor *, int> fan_out;
     std::unordered_map<const ggml_tensor *, int> node_indices;
 
@@ -3712,6 +3728,15 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
     const int max_fan_out = 3;
 
     std::vector<std::pair<int, int>> concurrent_node_ranges;
+
+    //save the original graph
+    std::vector<const ggml_tensor *> original_graph;
+    original_graph.reserve(cgraph->n_nodes);
+    for (int i = 0; i < cgraph->n_nodes; ++i) {
+        original_graph.push_back(cgraph->nodes[i]);
+    }
+    cuda_ctx->stream_context().original_graph = std::move(original_graph);
+
     for (const auto & [root_node, count] : fan_out) {
         if (count >= min_fan_out && count <= max_fan_out) {
             const int root_node_idx = node_indices[root_node];
 
@@ -3740,7 +3765,7 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
             //find the join point
             const ggml_tensor * join_node = nullptr;
 
-            auto belongs_to_branch = [&](const ggml_tensor * node, std::vector<const ggml_tensor *> & branch) -> bool {
+            const auto & belongs_to_branch = [&](const ggml_tensor * node, std::vector<const ggml_tensor *> & branch) -> bool {
                 for (const ggml_tensor * n : branch) {
                     if (n == node) {
                         return false;
 
@@ -3823,16 +3848,16 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
                 continue;
             }
 
-            GGML_ASSERT(cuda_ctx->stream_context().find(root_node) == cuda_ctx->stream_context().end());
-            cuda_ctx->stream_context().emplace(root_node, concurrent_event);
+            std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
+            GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
+            concurrent_events.emplace(root_node, concurrent_event);
             GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
             concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
 
             // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
             // example transformation:
             // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->
             // [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn]
-            // TODO: This breaks fusion within streams, how do we fix this?
             while (current_node_idx < join_node_idx) {
                 std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];
