
Commit 0e1d33c

ikawrakow and Iwan Kawrakow authored
Fuse add+add+fused_rms (ikawrakow#853)
* Fuse add+add+fused_rms
* Try this
* Macro to easily enable/disable fusion
* Various:
  * Check that all tensors involved are on the same device before applying fusion
  * Fuse sigmoid+scale+sum_rows+div
  * Fix the fused bailingmoe2 experts selection. The issue there was that the bias was not per row, but per expert group, so only the first n_per_group biases were used for all experts.

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 8aa3c2e commit 0e1d33c
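On the bailingmoe2 fix: the message above says the bias ended up being applied per expert group, so only the first n_per_group bias values were ever read. Below is a hypothetical sketch of that indexing issue, assuming expert scores laid out as n_groups x n_per_group and one bias entry per expert; names and layout are illustrative, not the kernel's actual code.

// Illustration only (hypothetical names): the bias must be indexed by the global expert id,
// not by the position inside an expert group, otherwise only bias[0..n_per_group-1] is used.
static void add_expert_bias(float * scores, const float * bias, int n_groups, int n_per_group) {
    for (int g = 0; g < n_groups; ++g) {
        for (int j = 0; j < n_per_group; ++j) {
            const int expert = g*n_per_group + j;
            // buggy variant: scores[expert] += bias[j];
            scores[expert] += bias[expert];   // fixed: one bias entry per expert
        }
    }
}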

File tree

8 files changed: +281 −54 lines changed


ggml/src/ggml-cuda.cu

Lines changed: 61 additions & 13 deletions
@@ -3094,12 +3094,28 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
 
 }
 
+static inline bool ops_are_same_device(const ggml_cgraph * cgraph, int first, int last) {
+    if (last <= first) return true;
+    int device = ((const ggml_backend_cuda_buffer_context *)cgraph->nodes[first]->buffer->context)->device;
+    for (int i = first; i <= last; ++i) {
+        auto node = cgraph->nodes[i];
+        if (((const ggml_backend_cuda_buffer_context *)node->buffer->context)->device != device) return false;
+        for (int j = 0; j < GGML_MAX_SRC; ++j) {
+            if (!node->src[j] || !node->src[j]->buffer) continue;
+            if (((const ggml_backend_cuda_buffer_context *)node->src[j]->buffer->context)->device != device) return false;
+        }
+    }
+    return true;
+}
+
 static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, const ggml_cgraph * cgraph, int & i) {
     // why is this here instead of mul_mat?
     if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
         ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
     }
 
+#define ENABLE_FUSION true
+
 #if IK_PRINT_TIMING
     int64_t tim1 = ggml_time_us();
 #endif
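The ops_are_same_device helper above backs the new fusion guards: it checks that every node in the candidate range, and every source tensor those nodes reference, lives in a buffer on the same CUDA device, so a fusion is skipped and the ops run unfused whenever the graph is split across GPUs. The ENABLE_FUSION macro defined just before the dispatch, and #undef'd at the end of the function, is the single switch from the commit message for enabling or disabling all of these fusions.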
@@ -3129,17 +3145,32 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_dup(ctx, dst);
             break;
         case GGML_OP_ADD:
-            if (i + 1 < cgraph->n_nodes &&
+            if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_ADD &&
+                cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
+                ggml_is_contiguous(dst->src[0]) &&
+                ggml_is_contiguous(dst->src[1]) &&
+                ggml_are_same_shape(dst->src[0], dst->src[1]) &&
+                dst == cgraph->nodes[i+1]->src[0] &&
+                ggml_is_contiguous(cgraph->nodes[i+1]->src[1]) &&
+                ggml_are_same_shape(dst, cgraph->nodes[i+1]->src[1]) &&
+                cgraph->nodes[i+1] == cgraph->nodes[i+2]->src[0] &&
+                ops_are_same_device(cgraph, i, i+2)) {
+                //printf("Fusing add->add->fused_rms of %s, %s, %s\n", dst->name, cgraph->nodes[i+1]->name, cgraph->nodes[i+2]->name);
+                ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+                i += 2;
+            }
+            else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM &&
                 ggml_is_contiguous(dst->src[0]) &&
                 ggml_is_contiguous(dst->src[1]) &&
-                ggml_are_same_shape(dst->src[0], dst->src[1])) {
+                ggml_are_same_shape(dst->src[0], dst->src[1]) &&
+                dst == cgraph->nodes[i+1]->src[0] && ops_are_same_device(cgraph, i, i+1)) {
                 ggml_cuda_op_fused_add_rms_norm(ctx, dst, cgraph->nodes[i+1]);
                 ++i;
             } else {
                 ggml_cuda_op_add(ctx, dst);
             }
-            //ggml_cuda_op_add(ctx, dst);
             break;
         case GGML_OP_ADD_ID:
             ggml_cuda_op_add_id(ctx, dst);
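For reference, the three matched nodes compute two residual additions followed by an RMS normalization. The CPU sketch below shows the intended per-row result; it assumes GGML_OP_FUSED_RMS_NORM means rms_norm of its first source multiplied elementwise by its second source (the norm weights), with eps the usual stabilizer from the norm node's parameters. The function name is illustrative, not part of the codebase.

#include <math.h>   // sqrtf

// Hedged reference for one row of length n: x and y feed the first ADD, z the second,
// w are the norm weights; out = rms_norm(x + y + z) * w.
static void add_add_rms_norm_ref(const float * x, const float * y, const float * z,
                                 const float * w, float * out, int n, float eps) {
    double sumsq = 0.0;
    for (int k = 0; k < n; ++k) {
        const float s = x[k] + y[k] + z[k];                      // the two fused additions
        out[k] = s;
        sumsq += (double)s * s;
    }
    const float scale = 1.0f/sqrtf((float)(sumsq/n) + eps);      // 1/RMS
    for (int k = 0; k < n; ++k) out[k] *= scale*w[k];            // normalize and apply weights
}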
@@ -3183,22 +3214,27 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                ggml_cuda_op_relu(ctx, dst);
                break;
            case GGML_UNARY_OP_SIGMOID:
-               if (i + 5 < cgraph->n_nodes &&
+               if (ENABLE_FUSION && i + 5 < cgraph->n_nodes &&
                    cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                    cgraph->nodes[i+2]->op == GGML_OP_ADD &&
                    cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
                    cgraph->nodes[i+4]->op == GGML_OP_VIEW &&
-                   cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS) {
+                   cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS && ops_are_same_device(cgraph, i, i+5)) {
                    cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]);
                    i += 5;
                }
-               else if (i + 4 < cgraph->n_nodes &&
+               else if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
                    cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                    cgraph->nodes[i+2]->op == GGML_OP_ADD &&
                    cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK &&
-                   cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS) {
+                   cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS && ops_are_same_device(cgraph, i, i+4)) {
                    cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+3]);
                    i += 4;
+               } else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
+                   cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
+                   cgraph->nodes[i+2]->op == GGML_OP_ADD && ops_are_same_device(cgraph, i, i+2)) {
+                   ggml_cuda_op_biased_sigmoid(ctx, cgraph->nodes[i+2]);
+                   i += 2;
                } else {
                    ggml_cuda_op_sigmoid(ctx, dst);
                }
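The new third branch handles the case where the sigmoid of the router logits is reshaped and a per-expert bias is added; ggml_cuda_op_biased_sigmoid then produces the result of all three nodes in one launch. A small CPU sketch of the value it should match, assuming (from the graph pattern, not the kernel source) a row-major [n_rows, n_experts] layout with the bias broadcast across rows; the reference function name is illustrative.

#include <math.h>   // expf

// out[r][e] = sigmoid(x[r][e]) + bias[e]; layout and names are assumptions for illustration.
static void biased_sigmoid_ref(const float * x, const float * bias, float * out,
                               int n_rows, int n_experts) {
    for (int r = 0; r < n_rows; ++r) {
        for (int e = 0; e < n_experts; ++e) {
            const int idx = r*n_experts + e;
            out[idx] = 1.0f/(1.0f + expf(-x[idx])) + bias[e];
        }
    }
}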
@@ -3309,12 +3345,13 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            if (i + 4 < cgraph->n_nodes &&
+            if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                 cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
                 cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
                 cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS &&
-                ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4])) {
+                ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4]) &&
+                ops_are_same_device(cgraph, i, i+4)) {
                 ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+4], cgraph->nodes[i+3]);
                 i += 4;
             } else {
@@ -3343,23 +3380,32 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_pool2d(ctx, dst);
             break;
         case GGML_OP_SUM_ROWS:
-            if (i + 1 < cgraph->n_nodes &&
+            if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_SCALE &&
+                cgraph->nodes[i+2]->op == GGML_OP_DIV &&
+                cgraph->nodes[i+1]->src[0] == dst &&
+                cgraph->nodes[i+2]->src[1] == cgraph->nodes[i+1] &&
+                cgraph->nodes[i+2]->src[0] == dst->src[0] && ops_are_same_device(cgraph, i, i+2)) {
+                ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+2]);
+                i += 2;
+            }
+            else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_DIV &&
                 cgraph->nodes[i+1]->src[1] == dst &&
-                cgraph->nodes[i+1]->src[0] == dst->src[0]) {
+                cgraph->nodes[i+1]->src[0] == dst->src[0] && ops_are_same_device(cgraph, i, i+1)) {
                 ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+1]);
                 ++i;
             } else {
                 ggml_cuda_op_sum_rows(ctx, dst);
             }
             break;
         case GGML_OP_ARGSORT:
-            if (i + 5 < cgraph->n_nodes &&
+            if (ENABLE_FUSION && i + 5 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
                 cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS &&
                 cgraph->nodes[i+3]->op == GGML_OP_RESHAPE &&
                 cgraph->nodes[i+4]->op == GGML_OP_SOFT_MAX &&
-                cgraph->nodes[i+5]->op == GGML_OP_RESHAPE) {
+                cgraph->nodes[i+5]->op == GGML_OP_RESHAPE && ops_are_same_device(cgraph, i, i+4)) {
                 cuda_openai_experts(ctx, dst, cgraph->nodes[i+4]);
                 i += 5;
             } else {
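The first SUM_ROWS branch folds a SCALE and a DIV node into the existing sum_rows+div fusion (part of the sigmoid+scale+sum_rows+div chain from the commit message): the matched pattern divides the original tensor by a scaled sum over its first dimension. A hedged sketch of the result the fused ggml_cuda_op_sum_rows_div call should be equivalent to is below; the scale factor s would come from the SCALE node's parameters, and the names are illustrative.

// Per-row reference: out[r][c] = x[r][c] / (s * sum_c x[r][c]); a sketch, not the kernel.
static void sum_rows_scale_div_ref(const float * x, float * out,
                                   int n_cols, int n_rows, float s) {
    for (int r = 0; r < n_rows; ++r) {
        float row_sum = 0.0f;
        for (int c = 0; c < n_cols; ++c) row_sum += x[r*n_cols + c];
        const float denom = s*row_sum;                                        // SCALE of SUM_ROWS
        for (int c = 0; c < n_cols; ++c) out[r*n_cols + c] = x[r*n_cols + c]/denom;  // DIV
    }
}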
@@ -3390,6 +3436,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
     printf("%s(%s): %d us\n", ggml_op_name(dst->op), dst->name, (int)(tim2 - tim1));
 #endif
 
+#undef ENABLE_FUSION
+
     return true;
 }
 