Skip to content

Commit bacf839

Browse files
authored
completely forgot to check the unary op
1 parent 4ec0e68 commit bacf839

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2807,11 +2807,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
2807 2807
if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
2808 2808
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
2809 2809
const ggml_tensor *scale = cgraph->nodes[node_idx];
2810+
const ggml_tensor *tanh = cgraph->nodes[node_idx+1];
2810 2811
const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
2811 2812

2812 2813
GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
2813 2814
GGML_ASSERT(scale->type == GGML_TYPE_F32);
2814 2815

2816+
if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
2817+
return false;
2818+
}
2819+
2815 2820
// Check for bias
2816 2821
if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
2817 2822
return false;

0 commit comments

Comments
 (0)