@@ -1017,10 +1017,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
10171017 "OPT_STEP_SGD" ,
10181018
10191019 "GLU" ,
1020- "XIELU" ,
10211020};
10221021
1023- static_assert (GGML_OP_COUNT == 91 , "GGML_OP_COUNT != 90" );
1022+ static_assert (GGML_OP_COUNT == 90 , "GGML_OP_COUNT != 90" );
10241023
10251024static const char * GGML_OP_SYMBOL [GGML_OP_COUNT ] = {
10261025 "none" ,
@@ -1122,10 +1121,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
11221121 "sgd(x)" ,
11231122
11241123 "glu(x)" ,
1125- "xielu(x)" ,
11261124};
11271125
1128- static_assert (GGML_OP_COUNT == 91 , "GGML_OP_COUNT != 90" );
1126+ static_assert (GGML_OP_COUNT == 90 , "GGML_OP_COUNT != 90" );
11291127
11301128static_assert (GGML_OP_POOL_COUNT == 2 , "GGML_OP_POOL_COUNT != 2" );
11311129
@@ -1145,9 +1143,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
11451143 "HARDSIGMOID" ,
11461144 "EXP" ,
11471145 "GELU_ERF" ,
1146+ "XIELU" ,
11481147};
11491148
1150- static_assert (GGML_UNARY_OP_COUNT == 15 , "GGML_UNARY_OP_COUNT != 15 " );
1149+ static_assert (GGML_UNARY_OP_COUNT == 16 , "GGML_UNARY_OP_COUNT != 16 " );
11511150
11521151static const char * GGML_GLU_OP_NAME [GGML_GLU_OP_COUNT ] = {
11531152 "REGLU" ,
@@ -2662,11 +2661,13 @@ struct ggml_tensor * ggml_xielu(
26622661 float eps ) {
26632662 struct ggml_tensor * result = ggml_dup_tensor (ctx , a );
26642663
2665- // Store the parameters as operation parameters
2666- float params [] = { beta + softplus (alpha_n ), softplus (alpha_p ), beta , eps };
2667- ggml_set_op_params (result , params , sizeof (params ));
2664+ ggml_set_op_params_i32 (result , 0 , (int32_t ) GGML_UNARY_OP_XIELU );
2665+ ggml_set_op_params_f32 (result , 1 , beta + softplus (alpha_n ));
2666+ ggml_set_op_params_f32 (result , 2 , softplus (alpha_p ));
2667+ ggml_set_op_params_f32 (result , 3 , beta );
2668+ ggml_set_op_params_f32 (result , 4 , eps );
26682669
2669- result -> op = GGML_OP_XIELU ;
2670+ result -> op = GGML_OP_UNARY ;
26702671 result -> src [0 ] = a ;
26712672
26722673 return result ;
@@ -7236,4 +7237,4 @@ bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, cons
72367237 if (p0 -> poll != p1 -> poll ) return false;
72377238 if (p0 -> strict_cpu != p1 -> strict_cpu ) return false;
72387239 return memcmp (p0 -> cpumask , p1 -> cpumask , GGML_MAX_N_THREADS ) == 0 ;
7239- }
7240+ }
0 commit comments