@@ -52,6 +52,15 @@ static inline float op_sqrt(float x) {
     return sqrtf(x);
 }
 
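+// xIELU: alpha_p*x^2 + beta*x for x > 0; for x <= 0, an expm1-based branch
+// scaled by alpha_n, with the expm1 argument clamped from above by eps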
+static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
+    if (x > 0.0f) {
+        return alpha_p * x * x + beta * x;
+    } else {
+        const float min_x_eps = fminf(x, eps);
+        return (expm1f(min_x_eps) - x) * alpha_n + beta * x;
+    }
+}
+
 static inline float op_sin(float x) {
     return sinf(x);
 }
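As a sanity check on the new op, here is a minimal standalone sketch; the parameter values are made up for illustration and are not taken from this patch. Both branches evaluate to 0 at x = 0, so the activation is continuous there.

#include <cmath>
#include <cstdio>
#include <initializer_list>

static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
    if (x > 0.0f) {
        return alpha_p * x * x + beta * x;
    } else {
        const float min_x_eps = fminf(x, eps);
        return (expm1f(min_x_eps) - x) * alpha_n + beta * x;
    }
}

int main() {
    // hypothetical parameter values, chosen only to exercise both branches
    const float alpha_n = 0.8f, alpha_p = 0.8f, beta = 0.5f, eps = 1e-6f;
    for (float x : { -2.0f, -0.5f, 0.0f, 0.5f, 2.0f }) {
        printf("xielu(% .2f) = % .6f\n", x, op_xielu(x, alpha_n, alpha_p, beta, eps));
    }
}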
@@ -64,8 +73,8 @@ static inline float op_log(float x) {
     return logf(x);
 }
 
-template <float (*op)(float), typename src0_t, typename dst_t>
-static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
+template <typename Op, typename src0_t, typename dst_t>
+static inline void vec_unary_op(const Op & op, int64_t n, dst_t * y, const src0_t * x) {
     constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
     constexpr auto f32_to_dst  = type_conversion_table<dst_t>::from_f32;
 
@@ -74,8 +83,8 @@ static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
     }
 }
 
-template <float (*op)(float), typename src0_t, typename dst_t>
-static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
+template <typename Op, typename src0_t, typename dst_t>
+static void apply_unary_op(const Op & op, const ggml_compute_params * params, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
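The signature changes in this and the previous hunk are what make a parameterized op like op_xielu possible: the old float (*op)(float) non-type template parameter only admits stateless functions, while a generic functor parameter also accepts capturing lambdas, whose state travels with the object. A minimal sketch of the distinction, with hypothetical names:

#include <cstdio>

template <float (*op)(float)>
static float apply_ptr(float x) { return op(x); }      // old style: compile-time function pointer

template <typename Op>
static float apply_functor(const Op & op, float x) { return op(x); }  // new style: any callable

static float op_neg_demo(float x) { return -x; }

int main() {
    printf("%f\n", apply_ptr<op_neg_demo>(2.0f));      // fine: a plain, stateless function
    const float alpha = 3.0f;
    auto scaled = [alpha](float x) { return alpha * x; };
    // apply_ptr<scaled>(2.0f);                        // does not compile: a capturing lambda
    //                                                 // cannot be a non-type template argument
    printf("%f\n", apply_functor(scaled, 2.0f));       // fine: the closure carries alpha along
}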
@@ -95,25 +104,25 @@ static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst
         dst_t        * dst_ptr  = (dst_t        *) ((char *)       dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
         const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-        vec_unary_op<op>(ne0, dst_ptr, src0_ptr);
+        vec_unary_op<decltype(op), src0_t, dst_t>(op, ne0, dst_ptr, src0_ptr);
     }
 }
 
 // TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
-template <float (*op)(float)>
-static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
+template <typename Op>
+static void unary_op(const Op & op, const ggml_compute_params * params, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
 
     /*  */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
-        apply_unary_op<op, float, float>(params, dst);
+        apply_unary_op<decltype(op), float, float>(op, params, dst);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
-        apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
+        apply_unary_op<decltype(op), ggml_fp16_t, ggml_fp16_t>(op, params, dst);
     } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
-        apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
+        apply_unary_op<decltype(op), ggml_bf16_t, ggml_bf16_t>(op, params, dst);
     } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
-        apply_unary_op<op, ggml_bf16_t, float>(params, dst);
+        apply_unary_op<decltype(op), ggml_bf16_t, float>(op, params, dst);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        apply_unary_op<op, ggml_fp16_t, float>(params, dst);
+        apply_unary_op<decltype(op), ggml_fp16_t, float>(op, params, dst);
     } else {
         fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
             ggml_type_name(dst->type), ggml_type_name(src0->type));
@@ -122,65 +131,89 @@ static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
 }
 
 void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_abs>(params, dst);
+    unary_op(op_abs, params, dst);
 }
 
 void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_sgn>(params, dst);
+    unary_op(op_sgn, params, dst);
 }
 
 void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_neg>(params, dst);
+    unary_op(op_neg, params, dst);
 }
 
 void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_step>(params, dst);
+    unary_op(op_step, params, dst);
 }
 
 void ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_tanh>(params, dst);
+    unary_op(op_tanh, params, dst);
 }
 
 void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_elu>(params, dst);
+    unary_op(op_elu, params, dst);
 }
 
 void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_relu>(params, dst);
+    unary_op(op_relu, params, dst);
 }
 
 void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_sigmoid>(params, dst);
+    unary_op(op_sigmoid, params, dst);
 }
 
 void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_hardsigmoid>(params, dst);
+    unary_op(op_hardsigmoid, params, dst);
 }
 
 void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_exp>(params, dst);
+    unary_op(op_exp, params, dst);
 }
 
 void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_hardswish>(params, dst);
+    unary_op(op_hardswish, params, dst);
 }
 
 void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_sqr>(params, dst);
+    unary_op(op_sqr, params, dst);
 }
 
 void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_sqrt>(params, dst);
+    unary_op(op_sqrt, params, dst);
 }
 
 void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_sin>(params, dst);
+    unary_op(op_sin, params, dst);
 }
 
 void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_cos>(params, dst);
+    unary_op(op_cos, params, dst);
 }
 
 void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op<op_log>(params, dst);
+    unary_op(op_log, params, dst);
+}
+
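+// softplus(x) = (1/beta) * log(1 + exp(beta*x)); returns the input unchanged
+// once beta*x exceeds the threshold, where expf would overflow anyway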
+static float softplus(float input, float beta = 1.0f, float threshold = 20.0f) {
+    if (input * beta > threshold) return input;
+    return (1/beta) * logf(1 + expf(beta * input));
 }
+
+void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
+    // get the xIELU parameters from the operation
+    const float * op_params = (const float *) dst->op_params;
+    float alpha_n = op_params[0];
+    float alpha_p = op_params[1];
+    const float beta = op_params[2];
+    const float eps  = op_params[3];
+
+    // alpha_p = softplus(alpha_p);
+    // alpha_n = beta + softplus(alpha_n);
+
+    const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
+        return op_xielu(f, alpha_n, alpha_p, beta, eps);
+    };
+
+    unary_op(xielu_op_params, params, dst);
+}
+
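The two commented-out lines in ggml_compute_forward_xielu suggest a softplus reparameterization (softplus output is always positive, so enabling them would force alpha_p > 0 and alpha_n > beta); whether the parameters arrive pre-transformed is not settled by this patch. A sketch of what those lines would compute, with made-up raw values:

#include <cmath>
#include <cstdio>

// same helper as in the patch above
static float softplus(float input, float beta = 1.0f, float threshold = 20.0f) {
    if (input * beta > threshold) return input;  // linear region: avoids expf overflow
    return (1/beta) * logf(1 + expf(beta * input));
}

int main() {
    float alpha_n = -0.5f, alpha_p = -0.5f;      // hypothetical raw (unconstrained) values
    const float beta = 0.5f;
    alpha_p = softplus(alpha_p);                 // always > 0
    alpha_n = beta + softplus(alpha_n);          // always > beta
    printf("alpha_p = %f, alpha_n = %f\n", alpha_p, alpha_n);
}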