Commit 6ccc073

Nexesenex and ikawrakow committed:

Simdify and multi-thread tanh

Picked from ik_llama.cpp, a llama.cpp fork maintained by Iwan Kawrakow.

Co-Authored-By: Kawrakow <[email protected]>
1 parent 2e012ca · commit 6ccc073


ggml/src/ggml.c

Lines changed: 57 additions & 6 deletions
@@ -2578,7 +2578,7 @@ inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
 inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
 inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
-inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
+//inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
 inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
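All of the SIMD kernels in this commit evaluate tanh through the identity tanh(x) = (e^{2x} - 1) / (e^{2x} + 1), which costs one vector exponential per lane instead of a libm tanhf call per element. A minimal scalar sketch (plain C and libm, not part of the patch) showing the rewrite agrees with tanhf:

#include <math.h>
#include <stdio.h>

// tanh via a single exponential: tanh(x) = (e^{2x} - 1) / (e^{2x} + 1)
static float tanh_via_exp(float x) {
    const float e2x = expf(2.0f * x);
    return (e2x - 1.0f) / (e2x + 1.0f);
}

int main(void) {
    for (float x = -4.0f; x <= 4.0f; x += 0.5f) {
        printf("x=% .2f  identity=% .7f  tanhf=% .7f\n",
               x, tanh_via_exp(x), tanhf(x));
    }
    return 0;
}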
@@ -2703,6 +2703,13 @@ inline static float32x4_t ggml_v_silu(float32x4_t x) {
     return vdivq_f32(x, one_plus_exp_neg_x);
 }
 
+inline static float32x4_t ggml_v_tanh(float32x4_t x) {
+    const float32x4_t one = vdupq_n_f32(1.0f);
+    const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
+    const float32x4_t exp_two_x = ggml_v_expf(two_x);
+    return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+}
+
 #elif defined(__AVX512F__) && defined(__AVX512DQ__)
 
 // adapted from arm limited optimized routine
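A side observation, not part of the commit: in plain float arithmetic the intermediate e^{2x} overflows to +inf once 2x exceeds log(FLT_MAX) ≈ 88.7 (so x ≳ 44), and (inf - 1) / (inf + 1) evaluates to inf/inf = NaN. Whether the vector kernels inherit this depends on how ggml_v_expf saturates for large inputs, so it is worth verifying there. A scalar demonstration:

#include <math.h>
#include <stdio.h>

int main(void) {
    const float x = 60.0f;             // 2x = 120 > log(FLT_MAX) ~ 88.7
    const float e2x = expf(2.0f * x);  // overflows to +inf in float
    printf("identity: %f\n", (e2x - 1.0f) / (e2x + 1.0f)); // inf/inf -> nan
    printf("tanhf:    %f\n", tanhf(x));                    // 1.000000
    return 0;
}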
@@ -2746,6 +2753,12 @@ inline static __m512 ggml_v_silu(__m512 x) {
     return _mm512_div_ps(x, one_plus_exp_neg_x);
 }
 
+inline static __m512 ggml_v_tanh(__m512 x) {
+    const __m512 one = _mm512_set1_ps(1.0f);
+    const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f)));
+    return _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
+}
+
 #elif defined(__AVX2__) && defined(__FMA__)
 
 // adapted from arm limited optimized routine
@@ -2801,6 +2814,12 @@ inline static __m256 ggml_v_silu(__m256 x) {
     return _mm256_div_ps(x, one_plus_exp_neg_x);
 }
 
+inline static __m256 ggml_v_tanh(__m256 x) {
+    const __m256 one = _mm256_set1_ps(1.0f);
+    const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f)));
+    return _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
+}
+
 #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
 
 #if defined(__FMA__)
@@ -2855,6 +2874,12 @@ inline static __m128 ggml_v_silu(__m128 x) {
     return _mm_div_ps(x, one_plus_exp_neg_x);
 }
 
+inline static __m128 ggml_v_tanh(__m128 x) {
+    const __m128 one = _mm_set1_ps(1.0f);
+    const __m128 exp_two_x = ggml_v_expf(_mm_mul_ps(x, _mm_set1_ps(2.f)));
+    return _mm_div_ps(_mm_sub_ps(exp_two_x, one), _mm_add_ps(exp_two_x, one));
+}
+
 #endif // __ARM_NEON / __AVX2__ / __SSE2__
 
 static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
@@ -2890,6 +2915,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     }
 }
 
+static void ggml_vec_tanh_f32(const int n, float * y, const float * x) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(y + i, ggml_v_tanh(_mm512_loadu_ps(x + i)));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(y + i, ggml_v_tanh(_mm256_loadu_ps(x + i)));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(y + i, ggml_v_tanh(_mm_loadu_ps(x + i)));
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(y + i, ggml_v_tanh(vld1q_f32(x + i)));
+    }
+#endif
+    for (; i < n; ++i) {
+        y[i] = tanhf(x[i]);
+    }
+}
+
 static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
     int i = 0;
     ggml_float sum = 0;
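The new ggml_vec_tanh_f32 follows the usual dispatch pattern: a SIMD main loop that strides by the vector width, then a scalar tanhf tail for the leftover elements. Below is a self-contained sketch of the same structure, SSE2-only for brevity; exp_ps is a hypothetical per-lane stand-in for ggml's internal ggml_v_expf:

#include <emmintrin.h>  // SSE2
#include <math.h>
#include <stdio.h>

/* Hypothetical stand-in for ggml_v_expf: expf applied per lane. */
static __m128 exp_ps(__m128 v) {
    float t[4];
    _mm_storeu_ps(t, v);
    for (int k = 0; k < 4; ++k) t[k] = expf(t[k]);
    return _mm_loadu_ps(t);
}

/* Same shape as the new ggml_vec_tanh_f32: SIMD main loop + scalar tail. */
static void vec_tanh_f32(int n, float * y, const float * x) {
    int i = 0;
    const __m128 one = _mm_set1_ps(1.0f);
    for (; i + 3 < n; i += 4) {  /* 4 floats per SSE2 iteration */
        const __m128 e = exp_ps(_mm_mul_ps(_mm_loadu_ps(x + i), _mm_set1_ps(2.f)));
        _mm_storeu_ps(y + i, _mm_div_ps(_mm_sub_ps(e, one), _mm_add_ps(e, one)));
    }
    for (; i < n; ++i) {         /* scalar tail covers the n % 4 leftovers */
        y[i] = tanhf(x[i]);
    }
}

int main(void) {
    float x[7] = { -3.f, -1.f, -0.5f, 0.f, 0.5f, 1.f, 3.f }, y[7];
    vec_tanh_f32(7, y, x);
    for (int i = 0; i < 7; ++i) {
        printf("tanh(% .2f) = % .6f  (tanhf: % .6f)\n", x[i], y[i], tanhf(x[i]));
    }
    return 0;
}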
@@ -11441,9 +11490,8 @@ static void ggml_compute_forward_tanh_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->ith != 0) {
-        return;
-    }
+    const int ith = params->ith;
+    const int nth = params->nth;
 
     assert(ggml_is_contiguous_1(src0));
     assert(ggml_is_contiguous_1(dst));
@@ -11452,7 +11500,7 @@ static void ggml_compute_forward_tanh_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
-    for (int i = 0; i < n; i++) {
+    for (int i = ith; i < n; i += nth) {
         ggml_vec_tanh_f32(nc,
                 (float *) ((char *)  dst->data + i*( dst->nb[1])),
                 (float *) ((char *) src0->data + i*(src0->nb[1])));
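On the threading side, ggml_compute_forward_tanh_f32 no longer bails out on every thread but the first; instead thread ith of nth handles rows ith, ith + nth, ith + 2*nth, and so on. A minimal pthreads sketch of that round-robin split (process_row is a hypothetical stand-in for the per-row ggml_vec_tanh_f32 call):

#include <pthread.h>
#include <stdio.h>

enum { N_ROWS = 10, N_THREADS = 3 };

/* Hypothetical stand-in for the per-row work. */
static void process_row(int i) { printf("row %d\n", i); }

struct slice { int ith, nth; };

static void * worker(void * p) {
    const struct slice * s = p;
    /* Same striding as the patched ggml_compute_forward_tanh_f32. */
    for (int i = s->ith; i < N_ROWS; i += s->nth) {
        process_row(i);
    }
    return NULL;
}

int main(void) {
    pthread_t t[N_THREADS];
    struct slice s[N_THREADS];
    for (int k = 0; k < N_THREADS; ++k) {
        s[k] = (struct slice){ .ith = k, .nth = N_THREADS };
        pthread_create(&t[k], NULL, worker, &s[k]);
    }
    for (int k = 0; k < N_THREADS; ++k) {
        pthread_join(t[k], NULL);
    }
    return 0;
}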
@@ -19344,7 +19392,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_UNARY_OP_SGN:
         case GGML_UNARY_OP_NEG:
         case GGML_UNARY_OP_STEP:
-        case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_ELU:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_SIGMOID:
@@ -19361,6 +19408,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_UNARY_OP_TANH:
+            {
+                n_tasks = MIN(ggml_nrows(node), n_threads);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
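Finally, GGML_UNARY_OP_TANH leaves the single-task group in ggml_get_n_tasks: the scheduler now grants it up to one task per row, capped at the pool size, so threads beyond the row count are never scheduled. A worked illustration using the same MIN convention as ggml.c:

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int n_rows = 3, n_threads = 8;
    /* A 4th task would find no row with i = ith, ith + nth, ... */
    printf("n_tasks = %d\n", MIN(n_rows, n_threads)); /* prints 3 */
    return 0;
}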
