Skip to content

Commit 4294dbf

Browse files
committed
Make xIELU a UNARY_OP
1 parent 2f68c03 commit 4294dbf

File tree

6 files changed: 25 additions, 23 deletions.

ggml/include/ggml.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,6 @@ extern "C" {
554554
GGML_OP_OPT_STEP_SGD,
555555

556556
GGML_OP_GLU,
557-
GGML_OP_XIELU,
558557

559558
GGML_OP_COUNT,
560559
};
@@ -575,6 +574,7 @@ extern "C" {
575574
GGML_UNARY_OP_HARDSIGMOID,
576575
GGML_UNARY_OP_EXP,
577576
GGML_UNARY_OP_GELU_ERF,
577+
GGML_UNARY_OP_XIELU,
578578

579579
GGML_UNARY_OP_COUNT,
580580
};

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1974,10 +1974,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
19741974
{
19751975
ggml_compute_forward_unary(params, tensor);
19761976
} break;
1977-
case GGML_OP_XIELU:
1978-
{
1979-
ggml_compute_forward_xielu(params, tensor);
1980-
} break;
19811977
case GGML_OP_GLU:
19821978
{
19831979
ggml_compute_forward_glu(params, tensor);
@@ -2144,7 +2140,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
21442140
case GGML_OP_ADD_ID:
21452141
case GGML_OP_ADD1:
21462142
case GGML_OP_ACC:
2147-
case GGML_OP_XIELU:
21482143
{
21492144
n_tasks = n_threads;
21502145
} break;
@@ -2192,6 +2187,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
21922187
case GGML_UNARY_OP_GELU_ERF:
21932188
case GGML_UNARY_OP_GELU_QUICK:
21942189
case GGML_UNARY_OP_SILU:
2190+
case GGML_UNARY_OP_XIELU:
21952191
{
21962192
n_tasks = n_threads;
21972193
} break;

ggml/src/ggml-cpu/ops.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9767,6 +9767,10 @@ void ggml_compute_forward_unary(
97679767
{
97689768
ggml_compute_forward_exp(params, dst);
97699769
} break;
9770+
case GGML_UNARY_OP_XIELU:
9771+
{
9772+
ggml_compute_forward_xielu(params, dst);
9773+
} break;
97709774
default:
97719775
{
97729776
GGML_ABORT("fatal error");

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2333,6 +2333,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
23332333
case GGML_UNARY_OP_ELU:
23342334
ggml_cuda_op_elu(ctx, dst);
23352335
break;
2336+
case GGML_UNARY_OP_XIELU:
2337+
ggml_cuda_op_xielu(ctx, dst);
2338+
break;
23362339
default:
23372340
return false;
23382341
}
@@ -2517,9 +2520,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
25172520
case GGML_OP_OPT_STEP_SGD:
25182521
ggml_cuda_opt_step_sgd(ctx, dst);
25192522
break;
2520-
case GGML_OP_XIELU:
2521-
ggml_cuda_op_xielu(ctx, dst);
2522-
break;
25232523
default:
25242524
return false;
25252525
}

ggml/src/ggml-cuda/unary.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -407,11 +407,12 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
407407
// xIELU
408408
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
409409
// Get the XIELU parameters from the operation
410+
410411
const float * op_params = (const float*)dst->op_params;
411-
float alpha_n = op_params[0];
412-
float alpha_p = op_params[1];
413-
const float beta = op_params[2];
414-
const float eps = op_params[3];
412+
float alpha_n = ggml_get_op_params_f32(dst, 1);
413+
float alpha_p = ggml_get_op_params_f32(dst, 2);
414+
const float beta = ggml_get_op_params_f32(dst, 3);
415+
const float eps = ggml_get_op_params_f32(dst, 4);
415416

416417
op_xielu_functor op(alpha_n, alpha_p, beta, eps);
417418
ggml_cuda_op_unary(ctx, dst, op);

ggml/src/ggml.c

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,10 +1017,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
10171017
"OPT_STEP_SGD",
10181018

10191019
"GLU",
1020-
"XIELU",
10211020
};
10221021

1023-
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 90");
1022+
static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
10241023

10251024
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10261025
"none",
@@ -1122,10 +1121,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
11221121
"sgd(x)",
11231122

11241123
"glu(x)",
1125-
"xielu(x)",
11261124
};
11271125

1128-
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 90");
1126+
static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
11291127

11301128
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
11311129

@@ -1145,9 +1143,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
11451143
"HARDSIGMOID",
11461144
"EXP",
11471145
"GELU_ERF",
1146+
"XIELU",
11481147
};
11491148

1150-
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
1149+
static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16");
11511150

11521151
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
11531152
"REGLU",
@@ -2662,11 +2661,13 @@ struct ggml_tensor * ggml_xielu(
26622661
float eps) {
26632662
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
26642663

2665-
// Store the parameters as operation parameters
2666-
float params[] = { beta + softplus(alpha_n), softplus(alpha_p), beta, eps };
2667-
ggml_set_op_params(result, params, sizeof(params));
2664+
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
2665+
ggml_set_op_params_f32(result, 1, beta + softplus(alpha_n));
2666+
ggml_set_op_params_f32(result, 2, softplus(alpha_p));
2667+
ggml_set_op_params_f32(result, 3, beta);
2668+
ggml_set_op_params_f32(result, 4, eps);
26682669

2669-
result->op = GGML_OP_XIELU;
2670+
result->op = GGML_OP_UNARY;
26702671
result->src[0] = a;
26712672

26722673
return result;
@@ -7236,4 +7237,4 @@ bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, cons
72367237
if (p0->poll != p1->poll ) return false;
72377238
if (p0->strict_cpu != p1->strict_cpu ) return false;
72387239
return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
7239-
}
7240+
}

0 commit comments

Comments
 (0)