@@ -984,6 +984,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
984984 "CROSS_ENTROPY_LOSS" ,
985985 "CROSS_ENTROPY_LOSS_BACK" ,
986986 "OPT_STEP_ADAMW" ,
987+
988+ "GLU" ,
987989};
988990
989991static_assert (GGML_OP_COUNT == 83 , "GGML_OP_COUNT != 83" );
@@ -1080,6 +1082,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10801082 "cross_entropy_loss(x,y)" ,
10811083 "cross_entropy_loss_back(x,y)" ,
10821084 "adamw(x)" ,
1085+
1086+ "glu(x)" ,
10831087};
10841088
10851089static_assert (GGML_OP_COUNT == 83 , "GGML_OP_COUNT != 83" );
@@ -1103,12 +1107,18 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
11031107 "HARDSIGMOID" ,
11041108 "EXP" ,
11051109 "GELU_ERF" ,
1110+ };
1111+
1112+ static_assert (GGML_UNARY_OP_COUNT == 15 , "GGML_UNARY_OP_COUNT != 15" );
1113+
1114+
1115+ static const char * GGML_GLU_OP_NAME [GGML_GLU_OP_COUNT ] = {
11061116 "REGLU" ,
11071117 "GEGLU" ,
11081118 "SWIGLU" ,
11091119};
11101120
1111- static_assert (GGML_UNARY_OP_COUNT == 18 , "GGML_UNARY_OP_COUNT != 18 " );
1121+ static_assert (GGML_GLU_OP_COUNT == 3 , "GGML_GLU_OP_COUNT != 3 " );
11121122
11131123
11141124static_assert (sizeof (struct ggml_object )%GGML_MEM_ALIGN == 0 , "ggml_object size must be a multiple of GGML_MEM_ALIGN" );
@@ -1213,11 +1223,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) {
12131223 return GGML_UNARY_OP_NAME [op ];
12141224}
12151225
1226+ const char * ggml_glu_op_name (enum ggml_glu_op op ) {
1227+ return GGML_GLU_OP_NAME [op ];
1228+ }
1229+
12161230const char * ggml_op_desc (const struct ggml_tensor * t ) {
12171231 if (t -> op == GGML_OP_UNARY ) {
12181232 enum ggml_unary_op uop = ggml_get_unary_op (t );
12191233 return ggml_unary_op_name (uop );
12201234 }
1235+ if (t -> op == GGML_OP_GLU ) {
1236+ enum ggml_glu_op gop = ggml_get_glu_op (t );
1237+ return ggml_glu_op_name (gop );
1238+ }
12211239 return ggml_op_name (t -> op );
12221240}
12231241
@@ -1736,6 +1754,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
17361754 return (enum ggml_unary_op ) ggml_get_op_params_i32 (tensor , 0 );
17371755}
17381756
1757+ enum ggml_glu_op ggml_get_glu_op (const struct ggml_tensor * tensor ) {
1758+ GGML_ASSERT (tensor -> op == GGML_OP_GLU );
1759+ return (enum ggml_glu_op ) ggml_get_op_params_i32 (tensor , 0 );
1760+ }
1761+
17391762const char * ggml_get_name (const struct ggml_tensor * tensor ) {
17401763 return tensor -> name ;
17411764}
@@ -2615,58 +2638,47 @@ struct ggml_tensor * ggml_exp_inplace(
26152638 return ggml_unary_inplace (ctx , a , GGML_UNARY_OP_EXP );
26162639}
26172640
2618- // ggml_reglu
2641+ // ggml_glu
26192642
2620- struct ggml_tensor * ggml_reglu (
2643+ struct ggml_tensor * ggml_glu (
26212644 struct ggml_context * ctx ,
2622- struct ggml_tensor * a ) {
2645+ struct ggml_tensor * a ,
2646+ enum ggml_glu_op op ) {
26232647 GGML_ASSERT (ggml_is_contiguous_1 (a ));
26242648
26252649 int64_t ne [GGML_MAX_DIMS ] = { a -> ne [0 ] / 2 }; for (int i = 1 ; i < GGML_MAX_DIMS ; i ++ ) ne [i ] = a -> ne [i ];
26262650 struct ggml_tensor * result = ggml_new_tensor_impl (ctx , a -> type , GGML_MAX_DIMS , ne , NULL , 0 );
26272651
2628- ggml_set_op_params_i32 (result , 0 , (int32_t ) GGML_UNARY_OP_REGLU );
2652+ ggml_set_op_params_i32 (result , 0 , (int32_t ) op );
26292653
2630- result -> op = GGML_OP_UNARY ;
2654+ result -> op = GGML_OP_GLU ;
26312655 result -> src [0 ] = a ;
26322656
26332657 return result ;
26342658}
26352659
2636- // ggml_geglu
2660+ // ggml_reglu
26372661
2638- struct ggml_tensor * ggml_geglu (
2662+ struct ggml_tensor * ggml_reglu (
26392663 struct ggml_context * ctx ,
26402664 struct ggml_tensor * a ) {
2641- GGML_ASSERT (ggml_is_contiguous_1 (a ));
2642-
2643- int64_t ne [GGML_MAX_DIMS ] = { a -> ne [0 ] / 2 }; for (int i = 1 ; i < GGML_MAX_DIMS ; i ++ ) ne [i ] = a -> ne [i ];
2644- struct ggml_tensor * result = ggml_new_tensor_impl (ctx , a -> type , GGML_MAX_DIMS , ne , NULL , 0 );
2645-
2646- ggml_set_op_params_i32 (result , 0 , (int32_t ) GGML_UNARY_OP_GEGLU );
2665+ return ggml_glu (ctx , a , GGML_GLU_OP_REGLU );
2666+ }
26472667
2648- result -> op = GGML_OP_UNARY ;
2649- result -> src [0 ] = a ;
2668+ // ggml_geglu
26502669
2651- return result ;
2670+ struct ggml_tensor * ggml_geglu (
2671+ struct ggml_context * ctx ,
2672+ struct ggml_tensor * a ) {
2673+ return ggml_glu (ctx , a , GGML_GLU_OP_GEGLU );
26522674}
26532675
26542676// ggml_swiglu
26552677
26562678struct ggml_tensor * ggml_swiglu (
26572679 struct ggml_context * ctx ,
26582680 struct ggml_tensor * a ) {
2659- GGML_ASSERT (ggml_is_contiguous_1 (a ));
2660-
2661- int64_t ne [GGML_MAX_DIMS ] = { a -> ne [0 ] / 2 }; for (int i = 1 ; i < GGML_MAX_DIMS ; i ++ ) ne [i ] = a -> ne [i ];
2662- struct ggml_tensor * result = ggml_new_tensor_impl (ctx , a -> type , GGML_MAX_DIMS , ne , NULL , 0 );
2663-
2664- ggml_set_op_params_i32 (result , 0 , (int32_t ) GGML_UNARY_OP_SWIGLU );
2665-
2666- result -> op = GGML_OP_UNARY ;
2667- result -> src [0 ] = a ;
2668-
2669- return result ;
2681+ return ggml_glu (ctx , a , GGML_GLU_OP_SWIGLU );
26702682}
26712683
26722684// ggml_norm
0 commit comments