@@ -73,8 +73,87 @@ static inline float op_log(float x) {
     return logf(x);
 }
 
+template <float (*op)(float), typename src0_t, typename dst_t>
+static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
+    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
+    constexpr auto f32_to_dst  = type_conversion_table<dst_t>::from_f32;
+
+    for (int i = 0; i < n; i++) {
+        y[i] = f32_to_dst(op(src0_to_f32(x[i])));
+    }
+}
+
+template <float (*op)(float), typename src0_t, typename dst_t>
+static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(dst_t));
+    GGML_ASSERT(nb00 == sizeof(src0_t));
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        dst_t * dst_ptr  = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
+        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+        vec_unary_op<op>(ne0, dst_ptr, src0_ptr);
+    }
+}
+
+// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
+template <float (*op)(float)>
+static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    /*  */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
+        apply_unary_op<op, float, float>(params, dst);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
+        apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
+        apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
+        apply_unary_op<op, ggml_bf16_t, float>(params, dst);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        apply_unary_op<op, ggml_fp16_t, float>(params, dst);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+            ggml_type_name(dst->type), ggml_type_name(src0->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+template <float (*op)(float, ggml_tensor *)>
+static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    /*  */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
+        apply_unary_op<op, float, float>(params, dst);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
+        apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
+        apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
+        apply_unary_op<op, ggml_bf16_t, float>(params, dst);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        apply_unary_op<op, ggml_fp16_t, float>(params, dst);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+            ggml_type_name(dst->type), ggml_type_name(src0->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+// Extend vec_unary_op to support functors
 template <typename Op, typename src0_t, typename dst_t>
-static inline void vec_unary_op(const Op & op, int64_t n, dst_t * y, const src0_t * x) {
+static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {
     constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
     constexpr auto f32_to_dst  = type_conversion_table<dst_t>::from_f32;
 
@@ -83,8 +162,9 @@ static inline void vec_unary_op(const Op & op, int64_t n, dst_t * y, const src0_
     }
 }
 
+// Extend apply_unary_op to support functors
 template <typename Op, typename src0_t, typename dst_t>
-static void apply_unary_op(const Op& op, const ggml_compute_params * params, ggml_tensor * dst) {
+static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
@@ -104,25 +184,25 @@ static void apply_unary_op(const Op& op, const ggml_compute_params * params, ggm
         dst_t * dst_ptr  = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
     const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-        vec_unary_op<decltype(op), src0_t, dst_t>(op, ne0, dst_ptr, src0_ptr);
+        vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);
     }
 }
 
-// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
+// Generic dispatcher for functors
 template <typename Op>
-static void unary_op(const Op& op, const ggml_compute_params * params, ggml_tensor * dst) {
+static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
     const ggml_tensor * src0 = dst->src[0];
 
     /*  */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
-        apply_unary_op<decltype(op), float, float>(op, params, dst);
+        apply_unary_op_functor<Op, float, float>(params, dst, op);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
-        apply_unary_op<decltype(op), ggml_fp16_t, ggml_fp16_t>(op, params, dst);
+        apply_unary_op_functor<Op, ggml_fp16_t, ggml_fp16_t>(params, dst, op);
     } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
-        apply_unary_op<decltype(op), ggml_bf16_t, ggml_bf16_t>(op, params, dst);
+        apply_unary_op_functor<Op, ggml_bf16_t, ggml_bf16_t>(params, dst, op);
     } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
-        apply_unary_op<decltype(op), ggml_bf16_t, float>(op, params, dst);
+        apply_unary_op_functor<Op, ggml_bf16_t, float>(params, dst, op);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        apply_unary_op<decltype(op), ggml_fp16_t, float>(op, params, dst);
+        apply_unary_op_functor<Op, ggml_fp16_t, float>(params, dst, op);
     } else {
         fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
             ggml_type_name(dst->type), ggml_type_name(src0->type));
@@ -131,80 +211,79 @@ static void unary_op(const Op& op, const ggml_compute_params * params, ggml_tens
 }
 
 void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_abs, params, dst);
+    unary_op<op_abs>(params, dst);
 }
 
 void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_sgn, params, dst);
+    unary_op<op_sgn>(params, dst);
 }
 
 void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_neg, params, dst);
+    unary_op<op_neg>(params, dst);
 }
 
 void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_step, params, dst);
+    unary_op<op_step>(params, dst);
 }
 
 void ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_tanh, params, dst);
+    unary_op<op_tanh>(params, dst);
 }
 
 void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_elu, params, dst);
+    unary_op<op_elu>(params, dst);
 }
 
 void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_relu, params, dst);
+    unary_op<op_relu>(params, dst);
 }
 
 void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_sigmoid, params, dst);
+    unary_op<op_sigmoid>(params, dst);
 }
 
 void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_hardsigmoid, params, dst);
+    unary_op<op_hardsigmoid>(params, dst);
 }
 
 void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_exp, params, dst);
+    unary_op<op_exp>(params, dst);
 }
 
 void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_hardswish, params, dst);
+    unary_op<op_hardswish>(params, dst);
 }
 
 void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_sqr, params, dst);
+    unary_op<op_sqr>(params, dst);
 }
 
 void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_sqrt, params, dst);
+    unary_op<op_sqrt>(params, dst);
 }
 
 void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_sin, params, dst);
+    unary_op<op_sin>(params, dst);
 }
 
 void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_cos, params, dst);
+    unary_op<op_cos>(params, dst);
 }
 
 void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
-    unary_op(op_log, params, dst);
+    unary_op<op_log>(params, dst);
 }
 
 void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
-    // Get the XIELU parameters from the operation
-    float alpha_n = ggml_get_op_params_f32(dst, 1);
-    float alpha_p = ggml_get_op_params_f32(dst, 2);
+    const float alpha_n = ggml_get_op_params_f32(dst, 1);
+    const float alpha_p = ggml_get_op_params_f32(dst, 2);
     const float beta    = ggml_get_op_params_f32(dst, 3);
     const float eps     = ggml_get_op_params_f32(dst, 4);
 
     const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
         return op_xielu(f, alpha_n, alpha_p, beta, eps);
     };
 
-    unary_op(xielu_op_params, params, dst);
+    unary_op_functor(params, dst, xielu_op_params);
 }
 
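For readers skimming the diff: the change splits unary dispatch into two styles, a non-type template parameter for stateless ops (`unary_op<op_abs>`) and a by-value functor argument for ops carrying runtime parameters (`unary_op_functor` with the xielu lambda). The standalone sketch below illustrates that pattern outside ggml; every name in it (`vec_apply`, `vec_apply_functor`, `alpha`) is illustrative only and not part of the ggml API.

// Minimal, self-contained sketch of the two dispatch styles (not ggml code).
#include <cmath>
#include <cstdio>

static inline float op_relu(float x) { return x > 0.0f ? x : 0.0f; }

// Stateless dispatch: the op is a non-type template parameter, so each
// instantiation bakes the function into the loop with no per-element indirection.
template <float (*op)(float)>
static void vec_apply(int n, float * y, const float * x) {
    for (int i = 0; i < n; i++) {
        y[i] = op(x[i]);
    }
}

// Functor dispatch: the op is passed by value, so a capturing lambda can carry
// runtime parameters with it (the role unary_op_functor plays for xielu).
template <typename Op>
static void vec_apply_functor(int n, float * y, const float * x, Op op) {
    for (int i = 0; i < n; i++) {
        y[i] = op(x[i]);
    }
}

int main() {
    const float x[4] = { -2.0f, -0.5f, 0.5f, 2.0f };
    float y[4];

    vec_apply<op_relu>(4, y, x);                  // compile-time op selection
    printf("relu: %g %g %g %g\n", y[0], y[1], y[2], y[3]);

    const float alpha = 0.1f;                     // illustrative runtime parameter
    vec_apply_functor(4, y, x, [alpha](float f) { // capturing lambda, like xielu's
        return f > 0.0f ? f : alpha * (std::exp(f) - 1.0f);
    });
    printf("elu:  %g %g %g %g\n", y[0], y[1], y[2], y[3]);
    return 0;
}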