Skip to content

Commit 52991e3

Browse files
authored
ggml : add more generic custom op, remove deprecated custom ops (#1183)
* ggml : add more generic ggml_custom op * ggml : remove deprecated custom ops
1 parent 1e965e8 commit 52991e3

File tree

7 files changed

+235
-516
lines changed

7 files changed

+235
-516
lines changed

include/ggml.h

Lines changed: 26 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -507,17 +507,12 @@ extern "C" {
507507

508508
GGML_OP_UNARY,
509509

510-
GGML_OP_MAP_UNARY,
511-
GGML_OP_MAP_BINARY,
512-
513-
GGML_OP_MAP_CUSTOM1_F32,
514-
GGML_OP_MAP_CUSTOM2_F32,
515-
GGML_OP_MAP_CUSTOM3_F32,
516-
517510
GGML_OP_MAP_CUSTOM1,
518511
GGML_OP_MAP_CUSTOM2,
519512
GGML_OP_MAP_CUSTOM3,
520513

514+
GGML_OP_CUSTOM,
515+
521516
GGML_OP_CROSS_ENTROPY_LOSS,
522517
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
523518
GGML_OP_OPT_STEP_ADAMW,
@@ -1916,83 +1911,6 @@ extern "C" {
19161911

19171912
// custom operators
19181913

1919-
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
1920-
typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
1921-
1922-
typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
1923-
typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
1924-
typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
1925-
1926-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
1927-
struct ggml_context * ctx,
1928-
struct ggml_tensor * a,
1929-
ggml_unary_op_f32_t fun),
1930-
"use ggml_map_custom1 instead");
1931-
1932-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
1933-
struct ggml_context * ctx,
1934-
struct ggml_tensor * a,
1935-
ggml_unary_op_f32_t fun),
1936-
"use ggml_map_custom1_inplace instead");
1937-
1938-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
1939-
struct ggml_context * ctx,
1940-
struct ggml_tensor * a,
1941-
struct ggml_tensor * b,
1942-
ggml_binary_op_f32_t fun),
1943-
"use ggml_map_custom2 instead");
1944-
1945-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
1946-
struct ggml_context * ctx,
1947-
struct ggml_tensor * a,
1948-
struct ggml_tensor * b,
1949-
ggml_binary_op_f32_t fun),
1950-
"use ggml_map_custom2_inplace instead");
1951-
1952-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
1953-
struct ggml_context * ctx,
1954-
struct ggml_tensor * a,
1955-
ggml_custom1_op_f32_t fun),
1956-
"use ggml_map_custom1 instead");
1957-
1958-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
1959-
struct ggml_context * ctx,
1960-
struct ggml_tensor * a,
1961-
ggml_custom1_op_f32_t fun),
1962-
"use ggml_map_custom1_inplace instead");
1963-
1964-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
1965-
struct ggml_context * ctx,
1966-
struct ggml_tensor * a,
1967-
struct ggml_tensor * b,
1968-
ggml_custom2_op_f32_t fun),
1969-
"use ggml_map_custom2 instead");
1970-
1971-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
1972-
struct ggml_context * ctx,
1973-
struct ggml_tensor * a,
1974-
struct ggml_tensor * b,
1975-
ggml_custom2_op_f32_t fun),
1976-
"use ggml_map_custom2_inplace instead");
1977-
1978-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
1979-
struct ggml_context * ctx,
1980-
struct ggml_tensor * a,
1981-
struct ggml_tensor * b,
1982-
struct ggml_tensor * c,
1983-
ggml_custom3_op_f32_t fun),
1984-
"use ggml_map_custom3 instead");
1985-
1986-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
1987-
struct ggml_context * ctx,
1988-
struct ggml_tensor * a,
1989-
struct ggml_tensor * b,
1990-
struct ggml_tensor * c,
1991-
ggml_custom3_op_f32_t fun),
1992-
"use ggml_map_custom3_inplace instead");
1993-
1994-
// custom operators v2
1995-
19961914
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
19971915
typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
19981916
typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
@@ -2048,6 +1966,30 @@ extern "C" {
20481966
int n_tasks,
20491967
void * userdata);
20501968

1969+
typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata);
1970+
1971+
GGML_API struct ggml_tensor * ggml_custom_4d(
1972+
struct ggml_context * ctx,
1973+
enum ggml_type type,
1974+
int64_t ne0,
1975+
int64_t ne1,
1976+
int64_t ne2,
1977+
int64_t ne3,
1978+
struct ggml_tensor ** args,
1979+
int n_args,
1980+
ggml_custom_op_t fun,
1981+
int n_tasks,
1982+
void * userdata);
1983+
1984+
GGML_API struct ggml_tensor * ggml_custom_inplace(
1985+
struct ggml_context * ctx,
1986+
struct ggml_tensor * a,
1987+
struct ggml_tensor ** args,
1988+
int n_args,
1989+
ggml_custom_op_t fun,
1990+
int n_tasks,
1991+
void * userdata);
1992+
20511993
// loss function
20521994

20531995
GGML_API struct ggml_tensor * ggml_cross_entropy_loss(

src/ggml-cpu/ggml-cpu.c

Lines changed: 15 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2027,41 +2027,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
20272027
{
20282028
ggml_compute_forward_rwkv_wkv7(params, tensor);
20292029
} break;
2030-
case GGML_OP_MAP_UNARY:
2031-
{
2032-
ggml_unary_op_f32_t fun;
2033-
memcpy(&fun, tensor->op_params, sizeof(fun));
2034-
ggml_compute_forward_map_unary(params, tensor, fun);
2035-
}
2036-
break;
2037-
case GGML_OP_MAP_BINARY:
2038-
{
2039-
ggml_binary_op_f32_t fun;
2040-
memcpy(&fun, tensor->op_params, sizeof(fun));
2041-
ggml_compute_forward_map_binary(params, tensor, fun);
2042-
}
2043-
break;
2044-
case GGML_OP_MAP_CUSTOM1_F32:
2045-
{
2046-
ggml_custom1_op_f32_t fun;
2047-
memcpy(&fun, tensor->op_params, sizeof(fun));
2048-
ggml_compute_forward_map_custom1_f32(params, tensor, fun);
2049-
}
2050-
break;
2051-
case GGML_OP_MAP_CUSTOM2_F32:
2052-
{
2053-
ggml_custom2_op_f32_t fun;
2054-
memcpy(&fun, tensor->op_params, sizeof(fun));
2055-
ggml_compute_forward_map_custom2_f32(params, tensor, fun);
2056-
}
2057-
break;
2058-
case GGML_OP_MAP_CUSTOM3_F32:
2059-
{
2060-
ggml_custom3_op_f32_t fun;
2061-
memcpy(&fun, tensor->op_params, sizeof(fun));
2062-
ggml_compute_forward_map_custom3_f32(params, tensor, fun);
2063-
}
2064-
break;
20652030
case GGML_OP_MAP_CUSTOM1:
20662031
{
20672032
ggml_compute_forward_map_custom1(params, tensor);
@@ -2077,6 +2042,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
20772042
ggml_compute_forward_map_custom3(params, tensor);
20782043
}
20792044
break;
2045+
case GGML_OP_CUSTOM:
2046+
{
2047+
ggml_compute_forward_custom(params, tensor);
2048+
}
2049+
break;
20802050
case GGML_OP_CROSS_ENTROPY_LOSS:
20812051
{
20822052
ggml_compute_forward_cross_entropy_loss(params, tensor);
@@ -2328,11 +2298,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
23282298
case GGML_OP_WIN_PART:
23292299
case GGML_OP_WIN_UNPART:
23302300
case GGML_OP_GET_REL_POS:
2331-
case GGML_OP_MAP_UNARY:
2332-
case GGML_OP_MAP_BINARY:
2333-
case GGML_OP_MAP_CUSTOM1_F32:
2334-
case GGML_OP_MAP_CUSTOM2_F32:
2335-
case GGML_OP_MAP_CUSTOM3_F32:
23362301
{
23372302
n_tasks = 1;
23382303
} break;
@@ -2366,6 +2331,16 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
23662331
n_tasks = MIN(p.n_tasks, n_threads);
23672332
}
23682333
} break;
2334+
case GGML_OP_CUSTOM:
2335+
{
2336+
struct ggml_custom_op_params p;
2337+
memcpy(&p, node->op_params, sizeof(p));
2338+
if (p.n_tasks == GGML_N_TASKS_MAX) {
2339+
n_tasks = n_threads;
2340+
} else {
2341+
n_tasks = MIN(p.n_tasks, n_threads);
2342+
}
2343+
} break;
23692344
case GGML_OP_CROSS_ENTROPY_LOSS:
23702345
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
23712346
case GGML_OP_OPT_STEP_ADAMW:

src/ggml-cpu/ops.cpp

Lines changed: 12 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -8264,152 +8264,6 @@ void ggml_compute_forward_rwkv_wkv7(
82648264
}
82658265
}
82668266

8267-
// ggml_compute_forward_map_unary
8268-
8269-
static void ggml_compute_forward_map_unary_f32(
8270-
const ggml_compute_params * params,
8271-
ggml_tensor * dst,
8272-
const ggml_unary_op_f32_t fun) {
8273-
8274-
const ggml_tensor * src0 = dst->src[0];
8275-
8276-
if (params->ith != 0) {
8277-
return;
8278-
}
8279-
8280-
assert(ggml_is_contiguous_1(src0));
8281-
assert(ggml_is_contiguous_1(dst));
8282-
assert(ggml_are_same_shape(src0, dst));
8283-
8284-
const int n = ggml_nrows(src0);
8285-
const int nc = src0->ne[0];
8286-
8287-
for (int i = 0; i < n; i++) {
8288-
fun(nc,
8289-
(float *) ((char *) dst->data + i*( dst->nb[1])),
8290-
(float *) ((char *) src0->data + i*(src0->nb[1])));
8291-
}
8292-
}
8293-
8294-
void ggml_compute_forward_map_unary(
8295-
const ggml_compute_params * params,
8296-
ggml_tensor * dst,
8297-
const ggml_unary_op_f32_t fun) {
8298-
8299-
const ggml_tensor * src0 = dst->src[0];
8300-
8301-
switch (src0->type) {
8302-
case GGML_TYPE_F32:
8303-
{
8304-
ggml_compute_forward_map_unary_f32(params, dst, fun);
8305-
} break;
8306-
default:
8307-
{
8308-
GGML_ABORT("fatal error");
8309-
}
8310-
}
8311-
}
8312-
8313-
// ggml_compute_forward_map_binary
8314-
8315-
static void ggml_compute_forward_map_binary_f32(
8316-
const ggml_compute_params * params,
8317-
ggml_tensor * dst,
8318-
const ggml_binary_op_f32_t fun) {
8319-
8320-
const ggml_tensor * src0 = dst->src[0];
8321-
const ggml_tensor * src1 = dst->src[1];
8322-
8323-
if (params->ith != 0) {
8324-
return;
8325-
}
8326-
8327-
assert(ggml_is_contiguous_1(src0));
8328-
assert(ggml_is_contiguous_1(src1));
8329-
assert(ggml_is_contiguous_1(dst));
8330-
assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8331-
8332-
const int n = ggml_nrows(src0);
8333-
const int nc = src0->ne[0];
8334-
8335-
for (int i = 0; i < n; i++) {
8336-
fun(nc,
8337-
(float *) ((char *) dst->data + i*( dst->nb[1])),
8338-
(float *) ((char *) src0->data + i*(src0->nb[1])),
8339-
(float *) ((char *) src1->data + i*(src1->nb[1])));
8340-
}
8341-
}
8342-
8343-
void ggml_compute_forward_map_binary(
8344-
const ggml_compute_params * params,
8345-
ggml_tensor * dst,
8346-
const ggml_binary_op_f32_t fun) {
8347-
8348-
const ggml_tensor * src0 = dst->src[0];
8349-
8350-
switch (src0->type) {
8351-
case GGML_TYPE_F32:
8352-
{
8353-
ggml_compute_forward_map_binary_f32(params, dst, fun);
8354-
} break;
8355-
default:
8356-
{
8357-
GGML_ABORT("fatal error");
8358-
}
8359-
}
8360-
}
8361-
8362-
// ggml_compute_forward_map_custom1
8363-
8364-
void ggml_compute_forward_map_custom1_f32(
8365-
const ggml_compute_params * params,
8366-
ggml_tensor * dst,
8367-
const ggml_custom1_op_f32_t fun) {
8368-
8369-
const ggml_tensor * a = dst->src[0];
8370-
8371-
if (params->ith != 0) {
8372-
return;
8373-
}
8374-
8375-
fun(dst, a);
8376-
}
8377-
8378-
// ggml_compute_forward_map_custom2
8379-
8380-
void ggml_compute_forward_map_custom2_f32(
8381-
const ggml_compute_params * params,
8382-
ggml_tensor * dst,
8383-
const ggml_custom2_op_f32_t fun) {
8384-
8385-
const ggml_tensor * a = dst->src[0];
8386-
const ggml_tensor * b = dst->src[1];
8387-
8388-
if (params->ith != 0) {
8389-
return;
8390-
}
8391-
8392-
fun(dst, a, b);
8393-
}
8394-
8395-
// ggml_compute_forward_map_custom3
8396-
8397-
void ggml_compute_forward_map_custom3_f32(
8398-
const ggml_compute_params * params,
8399-
ggml_tensor * dst,
8400-
const ggml_custom3_op_f32_t fun) {
8401-
8402-
const ggml_tensor * a = dst->src[0];
8403-
const ggml_tensor * b = dst->src[1];
8404-
const ggml_tensor * c = dst->src[1];
8405-
8406-
if (params->ith != 0) {
8407-
return;
8408-
}
8409-
8410-
fun(dst, a, b, c);
8411-
}
8412-
84138267
// ggml_compute_forward_map_custom1
84148268

84158269
void ggml_compute_forward_map_custom1(
@@ -8455,6 +8309,18 @@ void ggml_compute_forward_map_custom3(
84558309
p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
84568310
}
84578311

8312+
// ggml_compute_forward_custom
8313+
8314+
void ggml_compute_forward_custom(
8315+
const struct ggml_compute_params * params,
8316+
struct ggml_tensor * dst) {
8317+
8318+
struct ggml_custom_op_params p;
8319+
memcpy(&p, dst->op_params, sizeof(p));
8320+
8321+
p.fun(dst, params->ith, params->nth, p.userdata);
8322+
}
8323+
84588324
// ggml_compute_forward_cross_entropy_loss
84598325

84608326
static void ggml_compute_forward_cross_entropy_loss_f32(

0 commit comments

Comments
 (0)