
Commit 415b825

undo switch block
ggml-ci
1 parent 656584f commit 415b825

File tree: 1 file changed (+40, -43 lines)


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 40 additions & 43 deletions
@@ -2772,57 +2772,54 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         return false;
     }

-    switch (ops.size()) {
-        case 2:
-            if (ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
-                const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
-                const ggml_tensor *mul = cgraph->nodes[node_idx+1];
-
-                GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
-                GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
-
-                //rms norm only supports F32
-                if (mul->src[0]->type != GGML_TYPE_F32 ||
-                    mul->src[1]->type != GGML_TYPE_F32 ||
-                    mul->type != GGML_TYPE_F32) {
-                    return false;
-                }
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx+1];

-                //if rms norm is the B operand, then we don't handle broadcast
-                if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
-                    return false;
-                }
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);

-                //rms_norm kernel assumes contigous rows
-                if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
-                    return false;
-                }
-            }
-            break;
-        case 3:
-            if (ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
-                && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
-                const ggml_tensor *scale = cgraph->nodes[node_idx];
-                const ggml_tensor *tanh = cgraph->nodes[node_idx+1];
-                const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+        //rms norm only supports F32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }

-                GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
-                GGML_ASSERT(scale->type == GGML_TYPE_F32);
+        //if rms norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+            return false;
+        }

-                if (tanh->src[0] != scale || scale2->src[0] != tanh) {
-                    return false;
-                }
+        //rms_norm kernel assumes contigous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }

-                if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
-                    return false;
-                }
-            }
-            break;
-        default:
+        return true;
+    }
+
+    if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+        && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+        const ggml_tensor *scale = cgraph->nodes[node_idx];
+        const ggml_tensor *tanh = cgraph->nodes[node_idx+1];
+        const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+        GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+        if (tanh->src[0] != scale || scale2->src[0] != tanh) {
             return false;
+        }
+
+        if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+            return false;
+        }
+
+        return true;
     }

-    return true;
+    return false;
 }

 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
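
For readers skimming the diff: the old ggml_cuda_can_fuse body funneled every case of the switch on ops.size() through a break to one shared trailing "return true;", while the new body gives each supported pattern its own if block that returns true on success and falls back to "return false;". The standalone C++ sketch below mirrors only that control-flow change under simplified conditions; the two_op_checks_pass / three_op_checks_pass helpers are hypothetical placeholders, not ggml functions.

// Standalone sketch of the control-flow change in this commit; the
// *_checks_pass() helpers are placeholders for the real per-pattern
// validation in ggml_cuda_can_fuse.
#include <cstddef>

static bool two_op_checks_pass()   { return true; } // placeholder
static bool three_op_checks_pass() { return true; } // placeholder

// Before: cases that pass their checks "break" and fall through to one
// shared "return true"; only the default case rejects the size outright.
static bool can_fuse_before(std::size_t n_ops) {
    switch (n_ops) {
        case 2:
            if (!two_op_checks_pass()) {
                return false;
            }
            break;
        case 3:
            if (!three_op_checks_pass()) {
                return false;
            }
            break;
        default:
            return false;
    }
    return true;
}

// After: each supported size/pattern is an independent block that returns
// true on success, and the function falls back to "return false".
static bool can_fuse_after(std::size_t n_ops) {
    if (n_ops == 2 && two_op_checks_pass()) {
        return true;
    }
    if (n_ops == 3 && three_op_checks_pass()) {
        return true;
    }
    return false;
}

int main() {
    // Spot-check that both shapes agree on a few op-sequence lengths.
    const std::size_t sizes[] = {2, 3, 4};
    for (std::size_t n : sizes) {
        if (can_fuse_before(n) != can_fuse_after(n)) {
            return 1;
        }
    }
    return 0;
}

With the placeholder checks, the two versions return the same result for every size, which is the point of the refactor: same behavior, flatter control flow.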

Comments (0)