Skip to content

Commit 58e6e0f

Browse files
committed
Modify unary-ops.cpp to add the functor-based logic alongside the existing template system to retain optimizations
1 parent 62401d8 commit 58e6e0f

File tree

1 file changed

+109
-30
lines changed

1 file changed

+109
-30
lines changed

ggml/src/ggml-cpu/unary-ops.cpp

Lines changed: 109 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,87 @@ static inline float op_log(float x) {
7373
return logf(x);
7474
}
7575

76+
template <float (*op)(float), typename src0_t, typename dst_t>
77+
static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
78+
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
79+
constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
80+
81+
for (int i = 0; i < n; i++) {
82+
y[i] = f32_to_dst(op(src0_to_f32(x[i])));
83+
}
84+
}
85+
86+
// Apply the compile-time unary op over every row of src0 into dst.
// Rows are partitioned across threads via get_thread_range; within a row the
// data is contiguous, so the inner work is delegated to vec_unary_op.
template <float (*op)(float), typename src0_t, typename dst_t>
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    // rows must be contiguous and src/dst shapes identical
    GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));

    GGML_TENSOR_UNARY_OP_LOCALS

    // element strides must match the template element types exactly
    GGML_ASSERT( nb0 == sizeof(dst_t));
    GGML_ASSERT(nb00 == sizeof(src0_t));

    // [ir0, ir1) is this thread's slice of the flattened row index space
    const auto [ir0, ir1] = get_thread_range(params, src0);

    for (int64_t ir = ir0; ir < ir1; ++ir) {
        // decompose the flat row index ir into (i03, i02, i01) coordinates
        const int64_t i03 = ir/(ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

        // byte-stride addressing: nbX are byte strides per dimension
        dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);

        vec_unary_op<op>(ne0, dst_ptr, src0_ptr);
    }
}
110+
111+
// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
112+
template <float (*op)(float)>
113+
static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
114+
const ggml_tensor * src0 = dst->src[0];
115+
116+
/* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
117+
apply_unary_op<op, float, float>(params, dst);
118+
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
119+
apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
120+
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
121+
apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
122+
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
123+
apply_unary_op<op, ggml_bf16_t, float>(params, dst);
124+
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
125+
apply_unary_op<op, ggml_fp16_t, float>(params, dst);
126+
} else {
127+
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
128+
ggml_type_name(dst->type), ggml_type_name(src0->type));
129+
GGML_ABORT("fatal error");
130+
}
131+
}
132+
133+
template <float (*op)(float, ggml_tensor *)>
134+
static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) {
135+
const ggml_tensor * src0 = dst->src[0];
136+
137+
/* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
138+
apply_unary_op<op, float, float>(params, dst);
139+
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
140+
apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
141+
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
142+
apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
143+
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
144+
apply_unary_op<op, ggml_bf16_t, float>(params, dst);
145+
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
146+
apply_unary_op<op, ggml_fp16_t, float>(params, dst);
147+
} else {
148+
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
149+
ggml_type_name(dst->type), ggml_type_name(src0->type));
150+
GGML_ABORT("fatal error");
151+
}
152+
}
153+
154+
// Extend vec_unary_op to support functors
76155
template <typename Op, typename src0_t, typename dst_t>
77-
static inline void vec_unary_op(const Op & op, int64_t n, dst_t * y, const src0_t * x) {
156+
static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {
78157
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
79158
constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
80159

@@ -83,8 +162,9 @@ static inline void vec_unary_op(const Op & op, int64_t n, dst_t * y, const src0_
83162
}
84163
}
85164

165+
// Extend apply_unary_op to support functors
86166
template <typename Op, typename src0_t, typename dst_t>
87-
static void apply_unary_op(const Op& op, const ggml_compute_params * params, ggml_tensor * dst) {
167+
static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
88168
const ggml_tensor * src0 = dst->src[0];
89169

90170
GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
@@ -104,25 +184,25 @@ static void apply_unary_op(const Op& op, const ggml_compute_params * params, ggm
104184
dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
105185
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
106186

107-
vec_unary_op<decltype(op), src0_t, dst_t>(op, ne0, dst_ptr, src0_ptr);
187+
vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);
108188
}
109189
}
110190

111-
// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
191+
// Generic dispatcher for functors
112192
template <typename Op>
113-
static void unary_op(const Op& op, const ggml_compute_params * params, ggml_tensor * dst) {
193+
static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
114194
const ggml_tensor * src0 = dst->src[0];
115195

116196
/* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
117-
apply_unary_op<decltype(op), float, float>(op, params, dst);
197+
apply_unary_op_functor<Op, float, float>(params, dst, op);
118198
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
119-
apply_unary_op<decltype(op), ggml_fp16_t, ggml_fp16_t>(op, params, dst);
199+
apply_unary_op_functor<Op, ggml_fp16_t, ggml_fp16_t>(params, dst, op);
120200
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
121-
apply_unary_op<decltype(op), ggml_bf16_t, ggml_bf16_t>(op, params, dst);
201+
apply_unary_op_functor<Op, ggml_bf16_t, ggml_bf16_t>(params, dst, op);
122202
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
123-
apply_unary_op<decltype(op), ggml_bf16_t, float>(op, params, dst);
203+
apply_unary_op_functor<Op, ggml_bf16_t, float>(params, dst, op);
124204
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
125-
apply_unary_op<decltype(op), ggml_fp16_t, float>(op, params, dst);
205+
apply_unary_op_functor<Op, ggml_fp16_t, float>(params, dst, op);
126206
} else {
127207
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
128208
ggml_type_name(dst->type), ggml_type_name(src0->type));
@@ -131,80 +211,79 @@ static void unary_op(const Op& op, const ggml_compute_params * params, ggml_tens
131211
}
132212

133213
// Element-wise op_abs over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_abs>(params, dst);
}
136216

137217
// Element-wise op_sgn over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sgn>(params, dst);
}
140220

141221
// Element-wise op_neg over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_neg>(params, dst);
}
144224

145225
// Element-wise op_step over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_step>(params, dst);
}
148228

149229
// Element-wise op_tanh over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_tanh>(params, dst);
}
152232

153233
// Element-wise op_elu over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_elu>(params, dst);
}
156236

157237
// Element-wise op_relu over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_relu>(params, dst);
}
160240

161241
// Element-wise op_sigmoid over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sigmoid>(params, dst);
}
164244

165245
// Element-wise op_hardsigmoid over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_hardsigmoid>(params, dst);
}
168248

169249
// Element-wise op_exp over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_exp>(params, dst);
}
172252

173253
// Element-wise op_hardswish over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_hardswish>(params, dst);
}
176256

177257
// Element-wise op_sqr over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sqr>(params, dst);
}
180260

181261
// Element-wise op_sqrt over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sqrt>(params, dst);
}
184264

185265
// Element-wise op_sin over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_sin>(params, dst);
}
188268

189269
// Element-wise op_cos over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_cos>(params, dst);
}
192272

193273
// Element-wise op_log (logf) over src0 into dst (type dispatch via unary_op).
void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_log>(params, dst);
}
196276

197277
void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
198-
// Get the XIELU parameters from the operation
199-
float alpha_n = ggml_get_op_params_f32(dst, 1);
200-
float alpha_p = ggml_get_op_params_f32(dst, 2);
278+
const float alpha_n = ggml_get_op_params_f32(dst, 1);
279+
const float alpha_p = ggml_get_op_params_f32(dst, 2);
201280
const float beta = ggml_get_op_params_f32(dst, 3);
202281
const float eps = ggml_get_op_params_f32(dst, 4);
203282

204283
const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
205284
return op_xielu(f, alpha_n, alpha_p, beta, eps);
206285
};
207286

208-
unary_op(xielu_op_params, params, dst);
287+
unary_op_functor(params, dst, xielu_op_params);
209288
}
210289

0 commit comments

Comments
 (0)