Commit aa40944

Revert "Fuse add+add+fused_rms (ikawrakow#853)"
This reverts commit 0e1d33c.
1 parent fc6fb76 commit aa40944

8 files changed, +54 -281 lines changed


ggml/src/ggml-cuda.cu

Lines changed: 13 additions & 61 deletions
@@ -3095,28 +3095,12 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
 
 }
 
-static inline bool ops_are_same_device(const ggml_cgraph * cgraph, int first, int last) {
-    if (last <= first) return true;
-    int device = ((const ggml_backend_cuda_buffer_context *)cgraph->nodes[first]->buffer->context)->device;
-    for (int i = first; i <= last; ++i) {
-        auto node = cgraph->nodes[i];
-        if (((const ggml_backend_cuda_buffer_context *)node->buffer->context)->device != device) return false;
-        for (int j = 0; j < GGML_MAX_SRC; ++j) {
-            if (!node->src[j] || !node->src[j]->buffer) continue;
-            if (((const ggml_backend_cuda_buffer_context *)node->src[j]->buffer->context)->device != device) return false;
-        }
-    }
-    return true;
-}
-
 static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, const ggml_cgraph * cgraph, int & i) {
     // why is this here instead of mul_mat?
     if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
         ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
     }
 
-#define ENABLE_FUSION true
-
 #if IK_PRINT_TIMING
     int64_t tim1 = ggml_time_us();
 #endif
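
For context on the helper deleted above: ops_are_same_device guarded each fusion window against graphs whose nodes, or the sources they read, live on different CUDA devices, since a fused kernel launches on a single device and must not dereference buffers resident on another (as can happen with layer-split multi-GPU setups). A self-contained sketch of the same check, assuming a toy node type whose device is a plain field rather than being read from ggml_backend_cuda_buffer_context:

#include <vector>

// Simplified stand-in for the deleted helper; in ggml the device is
// read from the node's CUDA buffer context, not a plain field.
struct toy_node {
    int device;                    // device that owns this node's buffer
    std::vector<int> src_devices;  // devices of the node's source buffers
};

// True iff nodes[first..last] and all of their sources share one device.
// Fusing a window that spans devices would make one kernel read buffers
// that live on another GPU.
static bool ops_are_same_device(const std::vector<toy_node> & nodes, int first, int last) {
    if (last <= first) return true;
    const int device = nodes[first].device;
    for (int i = first; i <= last; ++i) {
        if (nodes[i].device != device) return false;
        for (int d : nodes[i].src_devices) {
            if (d != device) return false;
        }
    }
    return true;
}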
@@ -3146,32 +3130,17 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_dup(ctx, dst);
             break;
         case GGML_OP_ADD:
-            if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
-                cgraph->nodes[i+1]->op == GGML_OP_ADD &&
-                cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
-                ggml_is_contiguous(dst->src[0]) &&
-                ggml_is_contiguous(dst->src[1]) &&
-                ggml_are_same_shape(dst->src[0], dst->src[1]) &&
-                dst == cgraph->nodes[i+1]->src[0] &&
-                ggml_is_contiguous(cgraph->nodes[i+1]->src[1]) &&
-                ggml_are_same_shape(dst, cgraph->nodes[i+1]->src[1]) &&
-                cgraph->nodes[i+1] == cgraph->nodes[i+2]->src[0] &&
-                ops_are_same_device(cgraph, i, i+2)) {
-                //printf("Fusing add->add->fused_rms of %s, %s, %s\n", dst->name, cgraph->nodes[i+1]->name, cgraph->nodes[i+2]->name);
-                ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]);
-                i += 2;
-            }
-            else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
+            if (i + 1 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM &&
                 ggml_is_contiguous(dst->src[0]) &&
                 ggml_is_contiguous(dst->src[1]) &&
-                ggml_are_same_shape(dst->src[0], dst->src[1]) &&
-                dst == cgraph->nodes[i+1]->src[0] && ops_are_same_device(cgraph, i, i+1)) {
+                ggml_are_same_shape(dst->src[0], dst->src[1])) {
                 ggml_cuda_op_fused_add_rms_norm(ctx, dst, cgraph->nodes[i+1]);
                 ++i;
             } else {
                 ggml_cuda_op_add(ctx, dst);
             }
+            //ggml_cuda_op_add(ctx, dst);
             break;
         case GGML_OP_ADD_ID:
             ggml_cuda_op_add_id(ctx, dst);
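
The branch deleted above is the core of the reverted change: it collapsed the add -> add -> fused_rms_norm window into one kernel computing rms_norm(a + b + c) * w in a single pass over each row instead of three kernel launches. The following is a minimal CUDA sketch of that arithmetic only, not the reverted implementation; it assumes contiguous f32 tensors, one block per row with a power-of-two block size, and a hypothetical kernel name, and it recomputes the element sums rather than materializing the intermediate adds that other graph nodes might still consume.

#include <cuda_runtime.h>

// Hypothetical sketch of a fused add+add+rms_norm kernel.
// Launch with one block per row and blockDim.x == BLOCK, e.g.
//   fused_add_add_rms_norm_f32<256><<<nrows, 256>>>(a, b, c, w, y, ne0, 1e-6f);
template <int BLOCK>
__global__ void fused_add_add_rms_norm_f32(
        const float * a, const float * b, const float * c,  // inputs of the two adds
        const float * w,                                    // rms_norm weight
        float * y, int ne0, float eps) {
    const size_t row = blockIdx.x;
    a += row * ne0; b += row * ne0; c += row * ne0; y += row * ne0;

    // pass 1: per-thread partial sum of squares of s = a + b + c
    float sumsq = 0.0f;
    for (int i = threadIdx.x; i < ne0; i += BLOCK) {
        const float s = a[i] + b[i] + c[i];
        sumsq += s * s;
    }

    // block-wide tree reduction (BLOCK must be a power of two)
    __shared__ float buf[BLOCK];
    buf[threadIdx.x] = sumsq;
    __syncthreads();
    for (int off = BLOCK / 2; off > 0; off >>= 1) {
        if (threadIdx.x < off) buf[threadIdx.x] += buf[threadIdx.x + off];
        __syncthreads();
    }
    const float scale = rsqrtf(buf[0] / ne0 + eps);

    // pass 2: recompute the sum and write the normalized, weighted result
    for (int i = threadIdx.x; i < ne0; i += BLOCK) {
        const float s = a[i] + b[i] + c[i];
        y[i] = scale * s * w[i];
    }
}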
@@ -3218,27 +3187,22 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 ggml_cuda_op_relu(ctx, dst);
                 break;
             case GGML_UNARY_OP_SIGMOID:
-                if (ENABLE_FUSION && i + 5 < cgraph->n_nodes &&
+                if (i + 5 < cgraph->n_nodes &&
                     cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                     cgraph->nodes[i+2]->op == GGML_OP_ADD &&
                     cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
                     cgraph->nodes[i+4]->op == GGML_OP_VIEW &&
-                    cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS && ops_are_same_device(cgraph, i, i+5)) {
+                    cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS) {
                     cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]);
                     i += 5;
                 }
-                else if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
+                else if (i + 4 < cgraph->n_nodes &&
                     cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                     cgraph->nodes[i+2]->op == GGML_OP_ADD &&
                     cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK &&
-                    cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS && ops_are_same_device(cgraph, i, i+4)) {
+                    cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS) {
                     cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+3]);
                     i += 4;
-                } else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
-                    cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
-                    cgraph->nodes[i+2]->op == GGML_OP_ADD && ops_are_same_device(cgraph, i, i+2)) {
-                    ggml_cuda_op_biased_sigmoid(ctx, cgraph->nodes[i+2]);
-                    i += 2;
                 } else {
                     ggml_cuda_op_sigmoid(ctx, dst);
                 }
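
The third branch deleted above, with no counterpart on the new side, fused sigmoid -> reshape -> add into ggml_cuda_op_biased_sigmoid; elementwise, that window is just y = sigmoid(x) + b. A minimal CUDA sketch under stated assumptions: flat contiguous f32 buffers, a bias that repeats every nb elements (a per-expert gate bias added to each row of logits), and hypothetical names. The reshape only changes the view, so flat indexing suffices.

#include <cuda_runtime.h>

// Hypothetical sketch: y[i] = 1 / (1 + exp(-x[i])) + b[i % nb], a logistic
// sigmoid plus a bias vector of length nb broadcast across rows.
// Launch example: biased_sigmoid_f32<<<(n + 255) / 256, 256>>>(x, b, y, n, nb);
__global__ void biased_sigmoid_f32(const float * x, const float * b,
                                   float * y, int n, int nb) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    y[i] = 1.0f / (1.0f + expf(-x[i])) + b[i % nb];
}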
@@ -3349,13 +3313,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
+            if (i + 4 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                 cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
                 cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
                 cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS &&
-                ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4]) &&
-                ops_are_same_device(cgraph, i, i+4)) {
+                ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4])) {
                 ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+4], cgraph->nodes[i+3]);
                 i += 4;
             } else {
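
For reference, the soft_max window that keeps its fusion here (soft_max -> reshape -> argsort -> view -> get_rows) is the standard top-k MoE routing pattern: softmax over the expert logits, then the indices and probabilities of the k best experts. A scalar C++ sketch of what that five-node window computes for one token, with hypothetical names and no claim about the fused kernel's actual layout:

#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// Hypothetical scalar reference: softmax over n_expert logits, then the
// indices and probabilities of the top k experts (what view + get_rows
// extract from the argsort output in the unfused graph).
static void topk_moe_ref(const std::vector<float> & logits, int k,
                         std::vector<int> & top_idx, std::vector<float> & top_p) {
    const int n = (int)logits.size();
    // softmax with max-subtraction for numerical stability
    const float mx = *std::max_element(logits.begin(), logits.end());
    std::vector<float> p(n);
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) { p[i] = std::exp(logits[i] - mx); sum += p[i]; }
    for (float & v : p) v /= sum;
    // argsort by descending probability, keeping only the first k entries
    std::vector<int> idx(n);
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return p[a] > p[b]; });
    top_idx.assign(idx.begin(), idx.begin() + k);
    top_p.clear();
    for (int i = 0; i < k; ++i) top_p.push_back(p[top_idx[i]]);
}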
@@ -3387,32 +3350,23 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_pool2d(ctx, dst);
             break;
         case GGML_OP_SUM_ROWS:
-            if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
-                cgraph->nodes[i+1]->op == GGML_OP_SCALE &&
-                cgraph->nodes[i+2]->op == GGML_OP_DIV &&
-                cgraph->nodes[i+1]->src[0] == dst &&
-                cgraph->nodes[i+2]->src[1] == cgraph->nodes[i+1] &&
-                cgraph->nodes[i+2]->src[0] == dst->src[0] && ops_are_same_device(cgraph, i, i+2)) {
-                ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+2]);
-                i += 2;
-            }
-            else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
+            if (i + 1 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_DIV &&
                 cgraph->nodes[i+1]->src[1] == dst &&
-                cgraph->nodes[i+1]->src[0] == dst->src[0] && ops_are_same_device(cgraph, i, i+1)) {
+                cgraph->nodes[i+1]->src[0] == dst->src[0]) {
                 ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+1]);
                 ++i;
             } else {
                 ggml_cuda_op_sum_rows(ctx, dst);
             }
             break;
         case GGML_OP_ARGSORT:
-            if (ENABLE_FUSION && i + 5 < cgraph->n_nodes &&
+            if (i + 5 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
                 cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS &&
                 cgraph->nodes[i+3]->op == GGML_OP_RESHAPE &&
                 cgraph->nodes[i+4]->op == GGML_OP_SOFT_MAX &&
-                cgraph->nodes[i+5]->op == GGML_OP_RESHAPE && ops_are_same_device(cgraph, i, i+4)) {
+                cgraph->nodes[i+5]->op == GGML_OP_RESHAPE) {
                 cuda_openai_experts(ctx, dst, cgraph->nodes[i+4]);
                 i += 5;
             } else {
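
In the sum_rows branch above, the deleted three-node window (sum_rows -> scale -> div) additionally folded a scale of the row sums into the fusion; the surviving two-node window divides each element of src0 by its own row sum via ggml_cuda_op_sum_rows_div. A scalar C++ reference for the two-node case, assuming contiguous f32 rows and a hypothetical function name:

#include <vector>

// Hypothetical reference for the fused sum_rows -> div window:
// y[r][c] = x[r][c] / sum(x[r][.]), i.e. each row divided by its row sum.
static void sum_rows_div_ref(const std::vector<float> & x, int nrows, int ncols,
                             std::vector<float> & y) {
    y.resize(x.size());
    for (int r = 0; r < nrows; ++r) {
        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) sum += x[(size_t)r * ncols + c];
        for (int c = 0; c < ncols; ++c) y[(size_t)r * ncols + c] = x[(size_t)r * ncols + c] / sum;
    }
}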
@@ -3443,8 +3397,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
     printf("%s(%s): %d us\n", ggml_op_name(dst->op), dst->name, (int)(tim2 - tim1));
 #endif
 
-#undef ENABLE_FUSION
-
     return true;
 }
 