
Commit 459895c

slaren authored and ggerganov committed
ggml : add more generic custom op, remove deprecated custom ops (ggml/1183)
* ggml : add more generic ggml_custom op
* ggml : remove deprecated custom ops
1 parent e4bf72d commit 459895c
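The removed *_f32 map operators all carried deprecation notes pointing at the surviving ggml_map_custom* family ("use ggml_map_custom1 instead", and so on). As a rough illustration of that port, and not part of the commit itself, an old f32 unary callback could be rewritten against ggml_map_custom1 along these lines (neg_rows and the call site are hypothetical; the new callback receives the thread index ith, the thread count nth and a userdata pointer, and splits the rows itself):

#include "ggml.h"

// Illustrative only (not from this commit): porting an old ggml_map_custom1_f32
// style callback to ggml_map_custom1. Assumes contiguous F32 tensors.
static void neg_rows(struct ggml_tensor * dst, const struct ggml_tensor * a, int ith, int nth, void * userdata) {
    (void) userdata;
    GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(dst) && a->type == GGML_TYPE_F32);
    const int64_t nr = ggml_nrows(a);
    for (int64_t ir = ith; ir < nr; ir += nth) {              // interleave rows across the nth workers
        const float * x = (const float *) ((const char *) a->data   + ir*a->nb[1]);
        float       * y = (float       *) ((      char *) dst->data + ir*dst->nb[1]);
        for (int64_t i = 0; i < a->ne[0]; ++i) {
            y[i] = -x[i];
        }
    }
}

// old call site: ggml_map_custom1_f32(ctx, a, neg_rows_f32);
// new call site (ctx and a assumed to be in scope):
struct ggml_tensor * out = ggml_map_custom1(ctx, a, neg_rows, GGML_N_TASKS_MAX, NULL);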

6 files changed: +132, -485 lines changed

ggml/include/ggml.h

Lines changed: 26 additions & 84 deletions
@@ -507,17 +507,12 @@ extern "C" {
507507

508508
GGML_OP_UNARY,
509509

510-
GGML_OP_MAP_UNARY,
511-
GGML_OP_MAP_BINARY,
512-
513-
GGML_OP_MAP_CUSTOM1_F32,
514-
GGML_OP_MAP_CUSTOM2_F32,
515-
GGML_OP_MAP_CUSTOM3_F32,
516-
517510
GGML_OP_MAP_CUSTOM1,
518511
GGML_OP_MAP_CUSTOM2,
519512
GGML_OP_MAP_CUSTOM3,
520513

514+
GGML_OP_CUSTOM,
515+
521516
GGML_OP_CROSS_ENTROPY_LOSS,
522517
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
523518
GGML_OP_OPT_STEP_ADAMW,
@@ -1916,83 +1911,6 @@ extern "C" {
19161911

19171912
// custom operators
19181913

1919-
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
1920-
typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
1921-
1922-
typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
1923-
typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
1924-
typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
1925-
1926-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
1927-
struct ggml_context * ctx,
1928-
struct ggml_tensor * a,
1929-
ggml_unary_op_f32_t fun),
1930-
"use ggml_map_custom1 instead");
1931-
1932-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
1933-
struct ggml_context * ctx,
1934-
struct ggml_tensor * a,
1935-
ggml_unary_op_f32_t fun),
1936-
"use ggml_map_custom1_inplace instead");
1937-
1938-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
1939-
struct ggml_context * ctx,
1940-
struct ggml_tensor * a,
1941-
struct ggml_tensor * b,
1942-
ggml_binary_op_f32_t fun),
1943-
"use ggml_map_custom2 instead");
1944-
1945-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
1946-
struct ggml_context * ctx,
1947-
struct ggml_tensor * a,
1948-
struct ggml_tensor * b,
1949-
ggml_binary_op_f32_t fun),
1950-
"use ggml_map_custom2_inplace instead");
1951-
1952-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
1953-
struct ggml_context * ctx,
1954-
struct ggml_tensor * a,
1955-
ggml_custom1_op_f32_t fun),
1956-
"use ggml_map_custom1 instead");
1957-
1958-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
1959-
struct ggml_context * ctx,
1960-
struct ggml_tensor * a,
1961-
ggml_custom1_op_f32_t fun),
1962-
"use ggml_map_custom1_inplace instead");
1963-
1964-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
1965-
struct ggml_context * ctx,
1966-
struct ggml_tensor * a,
1967-
struct ggml_tensor * b,
1968-
ggml_custom2_op_f32_t fun),
1969-
"use ggml_map_custom2 instead");
1970-
1971-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
1972-
struct ggml_context * ctx,
1973-
struct ggml_tensor * a,
1974-
struct ggml_tensor * b,
1975-
ggml_custom2_op_f32_t fun),
1976-
"use ggml_map_custom2_inplace instead");
1977-
1978-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
1979-
struct ggml_context * ctx,
1980-
struct ggml_tensor * a,
1981-
struct ggml_tensor * b,
1982-
struct ggml_tensor * c,
1983-
ggml_custom3_op_f32_t fun),
1984-
"use ggml_map_custom3 instead");
1985-
1986-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
1987-
struct ggml_context * ctx,
1988-
struct ggml_tensor * a,
1989-
struct ggml_tensor * b,
1990-
struct ggml_tensor * c,
1991-
ggml_custom3_op_f32_t fun),
1992-
"use ggml_map_custom3_inplace instead");
1993-
1994-
// custom operators v2
1995-
19961914
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
19971915
typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
19981916
typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
@@ -2048,6 +1966,30 @@ extern "C" {
20481966
int n_tasks,
20491967
void * userdata);
20501968

1969+
typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata);
1970+
1971+
GGML_API struct ggml_tensor * ggml_custom_4d(
1972+
struct ggml_context * ctx,
1973+
enum ggml_type type,
1974+
int64_t ne0,
1975+
int64_t ne1,
1976+
int64_t ne2,
1977+
int64_t ne3,
1978+
struct ggml_tensor ** args,
1979+
int n_args,
1980+
ggml_custom_op_t fun,
1981+
int n_tasks,
1982+
void * userdata);
1983+
1984+
GGML_API struct ggml_tensor * ggml_custom_inplace(
1985+
struct ggml_context * ctx,
1986+
struct ggml_tensor * a,
1987+
struct ggml_tensor ** args,
1988+
int n_args,
1989+
ggml_custom_op_t fun,
1990+
int n_tasks,
1991+
void * userdata);
1992+
20511993
// loss function
20521994

20531995
GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
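For context, a minimal sketch of how the new generic op declared above might be used; this is my illustration, not part of the diff. It assumes that the tensors passed via args become the sources of the result (dst->src[0..n_args-1]), which is how the CPU forward path in the files below can hand them to the callback, and that ctx and a contiguous F32 tensor a already exist; scale_rows and scale are hypothetical names:

#include "ggml.h"

// Callback: dst = scale * args[0], row by row. Each worker handles every nth row.
static void scale_rows(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
    const float s = *(const float *) userdata;
    const struct ggml_tensor * src = dst->src[0];            // assumed: args[0] is attached here
    GGML_ASSERT(ggml_is_contiguous(src) && src->type == GGML_TYPE_F32);
    const int64_t nr = ggml_nrows(dst);
    for (int64_t ir = ith; ir < nr; ir += nth) {
        const float * x = (const float *) ((const char *) src->data + ir*src->nb[1]);
        float       * y = (float       *) ((      char *) dst->data + ir*dst->nb[1]);
        for (int64_t i = 0; i < dst->ne[0]; ++i) {
            y[i] = s*x[i];
        }
    }
}

// Graph-build time (ctx and a assumed to be in scope); userdata must outlive graph compute.
static float scale = 0.5f;
struct ggml_tensor * args[1] = { a };
struct ggml_tensor * out = ggml_custom_4d(ctx, GGML_TYPE_F32,
        a->ne[0], a->ne[1], a->ne[2], a->ne[3],
        args, 1, scale_rows, GGML_N_TASKS_MAX, &scale);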

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 15 additions & 40 deletions
@@ -2027,41 +2027,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
-        case GGML_OP_MAP_UNARY:
-            {
-                ggml_unary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_unary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_BINARY:
-            {
-                ggml_binary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_binary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM1_F32:
-            {
-                ggml_custom1_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom1_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM2_F32:
-            {
-                ggml_custom2_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom2_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM3_F32:
-            {
-                ggml_custom3_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom3_f32(params, tensor, fun);
-            }
-            break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);

@@ -2077,6 +2042,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_custom3(params, tensor);
             }
             break;
+        case GGML_OP_CUSTOM:
+            {
+                ggml_compute_forward_custom(params, tensor);
+            }
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
             {
                 ggml_compute_forward_cross_entropy_loss(params, tensor);

@@ -2328,11 +2298,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_MAP_UNARY:
-        case GGML_OP_MAP_BINARY:
-        case GGML_OP_MAP_CUSTOM1_F32:
-        case GGML_OP_MAP_CUSTOM2_F32:
-        case GGML_OP_MAP_CUSTOM3_F32:
             {
                 n_tasks = 1;
             } break;

@@ -2366,6 +2331,16 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                    n_tasks = MIN(p.n_tasks, n_threads);
                }
            } break;
+        case GGML_OP_CUSTOM:
+            {
+                struct ggml_custom_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p.n_tasks, n_threads);
+                }
+            } break;
        case GGML_OP_CROSS_ENTROPY_LOSS:
        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
        case GGML_OP_OPT_STEP_ADAMW:
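The scheduling added here mirrors the existing map_custom path: the n_tasks value stored with the op caps the number of workers (GGML_N_TASKS_MAX means "use all n_threads"), and the callback is then invoked once per worker with its ith/nth pair, so the callback has to partition the work itself. A minimal sketch of the usual contiguous row split inside such a callback (my illustration; my_custom is a hypothetical name):

// Contiguous block split of rows [0, nr) across nth workers; worker ith handles [ir0, ir1).
static void my_custom(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
    (void) userdata;
    const int64_t nr  = ggml_nrows(dst);
    const int64_t dr  = (nr + nth - 1)/nth;                  // ceil(nr/nth) rows per worker
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = ir0 + dr < nr ? ir0 + dr : nr;
    for (int64_t ir = ir0; ir < ir1; ++ir) {
        // process row ir of dst (and of dst->src[...] as needed)
    }
}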

ggml/src/ggml-cpu/ops.cpp

Lines changed: 12 additions & 146 deletions
@@ -8268,152 +8268,6 @@ void ggml_compute_forward_rwkv_wkv7(
     }
 }
 
-// ggml_compute_forward_map_unary
-
-static void ggml_compute_forward_map_unary_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-void ggml_compute_forward_map_unary(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_unary_f32(params, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_map_binary
-
-static void ggml_compute_forward_map_binary_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(src1));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
-    }
-}
-
-void ggml_compute_forward_map_binary(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_binary_f32(params, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_map_custom1
-
-void ggml_compute_forward_map_custom1_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_custom1_op_f32_t fun) {
-
-    const ggml_tensor * a = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a);
-}
-
-// ggml_compute_forward_map_custom2
-
-void ggml_compute_forward_map_custom2_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_custom2_op_f32_t fun) {
-
-    const ggml_tensor * a = dst->src[0];
-    const ggml_tensor * b = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a, b);
-}
-
-// ggml_compute_forward_map_custom3
-
-void ggml_compute_forward_map_custom3_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_custom3_op_f32_t fun) {
-
-    const ggml_tensor * a = dst->src[0];
-    const ggml_tensor * b = dst->src[1];
-    const ggml_tensor * c = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a, b, c);
-}
-
 // ggml_compute_forward_map_custom1
 
 void ggml_compute_forward_map_custom1(

@@ -8459,6 +8313,18 @@ void ggml_compute_forward_map_custom3(
     p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
 }
 
+// ggml_compute_forward_custom
+
+void ggml_compute_forward_custom(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    struct ggml_custom_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, params->ith, params->nth, p.userdata);
+}
+
 // ggml_compute_forward_cross_entropy_loss
 
 static void ggml_compute_forward_cross_entropy_loss_f32(
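The in-place builder pairs naturally with callbacks that only rewrite the tensor they are given; by analogy with the other _inplace constructors, the result presumably aliases a, so writing to dst->data mutates a directly. A minimal sketch, assuming a contiguous F32 tensor a and a context ctx exist, and assuming that NULL/0 for args/n_args is accepted when no extra tensors are needed (clamp_inplace and limit are hypothetical names):

// Clamp the values of `a` to [-limit, limit], in place, split element-wise across workers.
static void clamp_inplace(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
    const float limit = *(const float *) userdata;
    GGML_ASSERT(ggml_is_contiguous(dst) && dst->type == GGML_TYPE_F32);
    float * x = (float *) dst->data;
    const int64_t n = ggml_nelements(dst);
    for (int64_t i = ith; i < n; i += nth) {                 // interleaved element split
        if (x[i] >  limit) x[i] =  limit;
        if (x[i] < -limit) x[i] = -limit;
    }
}

// Graph-build time; userdata must outlive graph compute.
static float limit = 1.0f;
struct ggml_tensor * out = ggml_custom_inplace(ctx, a, /*args=*/NULL, /*n_args=*/0,
        clamp_inplace, GGML_N_TASKS_MAX, &limit);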
