
Commit 3ef6c8c

ggml : add ggml_add_id (#13)
* ggml : add ggml_add_id
* add cuda impl
* llama : add weight support check for add_id
* perf opt
* add vulkan impl
* rename cuda files
* add metal impl
* allow in-place ggml_add_id
* llama : keep biases on CPU with --cpu-moe
1 parent ec95c0e commit 3ef6c8c

File tree: 21 files changed, +473 -25 lines


common/arg.cpp
Lines changed: 3 additions & 3 deletions

@@ -2384,9 +2384,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--cpu-moe"},
         "use CPU for Mixture of Experts (MoE) weights",
         [](common_params & params) {
-            params.tensor_buft_overrides.push_back({"\\.ffn_up_exps\\.weight$", ggml_backend_cpu_buffer_type()});
-            params.tensor_buft_overrides.push_back({"\\.ffn_down_exps\\.weight$", ggml_backend_cpu_buffer_type()});
-            params.tensor_buft_overrides.push_back({"\\.ffn_gate_exps\\.weight$", ggml_backend_cpu_buffer_type()});
+            params.tensor_buft_overrides.push_back({"\\.ffn_up_exps\\.", ggml_backend_cpu_buffer_type()});
+            params.tensor_buft_overrides.push_back({"\\.ffn_down_exps\\.", ggml_backend_cpu_buffer_type()});
+            params.tensor_buft_overrides.push_back({"\\.ffn_gate_exps\\.", ggml_backend_cpu_buffer_type()});
         }
     ).set_env("LLAMA_ARG_CPU_MOE"));
     add_opt(common_arg(
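
Dropping the "weight$" anchor is what lets --cpu-moe keep the new expert biases on the CPU as well: the override patterns are matched against full tensor names, so the shorter pattern now catches both the ".weight" and the ".bias" expert tensors. Below is a minimal standalone sketch of that matching behaviour; it assumes the patterns are applied with std::regex and uses illustrative tensor names, so it is not code from the commit.

#include <cstdio>
#include <regex>
#include <string>

int main() {
    const std::string old_pat = "\\.ffn_up_exps\\.weight$"; // previous --cpu-moe pattern
    const std::string new_pat = "\\.ffn_up_exps\\.";        // pattern after this commit

    const char * names[] = {
        "blk.0.ffn_up_exps.weight",
        "blk.0.ffn_up_exps.bias",   // hypothetical expert bias tensor kept on CPU
    };

    for (const char * name : names) {
        printf("%-28s old: %d  new: %d\n", name,
               (int) std::regex_search(name, std::regex(old_pat)),
               (int) std::regex_search(name, std::regex(new_pat)));
    }
    return 0;
}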

ggml/include/ggml.h
Lines changed: 18 additions & 0 deletions

@@ -304,6 +304,16 @@
     GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
     GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

+#define GGML_TENSOR_TERNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
 #define GGML_TENSOR_BINARY_OP_LOCALS01 \
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
     GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)

@@ -440,6 +450,7 @@ extern "C" {

     GGML_OP_DUP,
     GGML_OP_ADD,
+    GGML_OP_ADD_ID,
     GGML_OP_ADD1,
     GGML_OP_ACC,
     GGML_OP_SUB,

@@ -834,6 +845,13 @@ extern "C" {
             struct ggml_tensor * b,
             enum ggml_type type);

+    // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
+    GGML_API struct ggml_tensor * ggml_add_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b,
+            struct ggml_tensor * ids);
+
     GGML_API struct ggml_tensor * ggml_add1(
             struct ggml_context * ctx,
             struct ggml_tensor * a,
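
The new ggml_add_id operator adds a row of b, selected per (i1, i2) by the integer ids tensor, to the corresponding row of a: dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]. A minimal usage sketch follows; the helper name and the shape comments are assumptions inferred from the API comment and the CUDA launch dimensions later in this commit, not code from the commit itself.

// Sketch: add per-expert biases to expert activations after an expert matmul.
// Assumed shapes (inferred, not authoritative):
//   cur      : [ne0, n_expert_used, n_tokens]  activations, GGML_TYPE_F32
//   exp_bias : [ne0, n_expert]                 one bias row per expert, GGML_TYPE_F32
//   selected : [n_expert_used, n_tokens]       chosen expert per slot, GGML_TYPE_I32
#include "ggml.h"

static struct ggml_tensor * add_expert_bias(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,
        struct ggml_tensor  * exp_bias,
        struct ggml_tensor  * selected) {
    // dst[i0, i1, i2] = cur[i0, i1, i2] + exp_bias[i0, selected[i1, i2]]
    return ggml_add_id(ctx, cur, exp_bias, selected);
}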

ggml/src/ggml-alloc.c
Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_DIAG_MASK_ZERO:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:

ggml/src/ggml-cpu/ggml-cpu.c
Lines changed: 6 additions & 0 deletions

@@ -1676,6 +1676,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add(params, tensor);
             } break;
+        case GGML_OP_ADD_ID:
+            {
+                ggml_compute_forward_add_id(params, tensor);
+            } break;
         case GGML_OP_ADD1:
             {
                 ggml_compute_forward_add1(params, tensor);

@@ -2117,6 +2121,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_DUP:
         case GGML_OP_CONT:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_ACC:
             {

@@ -2680,6 +2685,7 @@ struct ggml_cplan ggml_graph_plan(
                 }
             } break;
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
             {
                 if (ggml_is_quantized(node->src[0]->type)) {

ggml/src/ggml-cpu/ops.cpp
Lines changed: 71 additions & 0 deletions

@@ -1311,6 +1311,77 @@ void ggml_compute_forward_add(
     }
 }

+// ggml_compute_forward_add_id
+
+static void ggml_compute_forward_add_id_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        // src1 indices
+        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+        GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+        ggml_vec_add_f32(ne0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                (float *) ((char *) src1->data + i11*nb11));
+    }
+}
+
+void ggml_compute_forward_add_id(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_id_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
+            }
+    }
+}
+
 // ggml_compute_forward_add1

 static void ggml_compute_forward_add1_f32(
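
For reference, the computation implemented above can be written as a naive scalar loop. The sketch below assumes contiguous f32 tensors with a row-major layout matching the index decomposition in ggml_compute_forward_add_id_f32; the function and parameter names are illustrative, not part of the commit.

// Naive reference: dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
#include <assert.h>
#include <stdint.h>

static void add_id_ref(
        float * dst, const float * a, const float * b, const int32_t * ids,
        int64_t ne0,       // row length (columns)
        int64_t n_used,    // rows per token (e.g. experts used per token)
        int64_t n_tokens,  // number of tokens
        int64_t n_expert)  // number of rows in b
{
    for (int64_t i2 = 0; i2 < n_tokens; ++i2) {
        for (int64_t i1 = 0; i1 < n_used; ++i1) {
            const int32_t e = ids[i2*n_used + i1]; // row of b selected for this slot
            assert(e >= 0 && e < n_expert);
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                dst[(i2*n_used + i1)*ne0 + i0] =
                        a[(i2*n_used + i1)*ne0 + i0] + b[e*ne0 + i0];
            }
        }
    }
}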

ggml/src/ggml-cpu/ops.h
Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ extern "C" {

 void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_add_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml-cpu/vec.h
Lines changed: 16 additions & 1 deletion

@@ -55,7 +55,22 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x)

 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
+
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
+    int i = 0;
+#if defined(__AVX2__)
+    for (; i + 7 < n; i += 8) {
+        __m256 vx = _mm256_loadu_ps(x + i);
+        __m256 vy = _mm256_loadu_ps(y + i);
+        __m256 vz = _mm256_add_ps(vx, vy);
+        _mm256_storeu_ps(z + i, vz);
+    }
+#endif
+    for (; i < n; ++i) {
+        z[i] = x[i] + y[i];
+    }
+}
+
 inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
     for (int i = 0; i < n; ++i) {
         z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));

ggml/src/ggml-cuda/add-id.cu
Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+#include "add-id.cuh"
+
+static __global__ void add_id_kernel(
+        const float * src0, const float * src1, const int32_t * src2, float * dst,
+        int64_t ne0, int64_t ne1,
+        size_t nb01, size_t nb02,
+        size_t nb11,
+        size_t nb21
+    ) {
+
+    const int64_t i1 = blockIdx.x;
+    const int64_t i2 = blockIdx.y;
+
+    const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21);
+
+    const size_t nb1 = ne0 * sizeof(float);
+    const size_t nb2 = ne1 * nb1;
+
+    float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
+    const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02);
+    const float * src1_row = (const float *)((char *)src1 + i11*nb11);
+
+    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
+    }
+}
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb20 == sizeof(int32_t));
+
+    const float   * src0_d = (const float   *)src0->data;
+    const float   * src1_d = (const float   *)src1->data;
+    const int32_t * src2_d = (const int32_t *)src2->data;
+    float * dst_d = (float *)dst->data;
+
+    int threads = std::min((int)ne00, 768); // cols
+    dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
+    add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
+        src0_d, src1_d, src2_d, dst_d,
+        ne0, ne1,
+        nb01, nb02,
+        nb11,
+        nb21
+    );
+}

ggml/src/ggml-cuda/add-id.cuh
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu
Lines changed: 5 additions & 0 deletions

@@ -4,6 +4,7 @@

 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
+#include "ggml-cuda/add-id.cuh"
 #include "ggml-cuda/arange.cuh"
 #include "ggml-cuda/argmax.cuh"
 #include "ggml-cuda/argsort.cuh"

@@ -2259,6 +2260,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ADD1: // TODO: more efficient implementation
            ggml_cuda_op_add(ctx, dst);
            break;
+        case GGML_OP_ADD_ID:
+            ggml_cuda_op_add_id(ctx, dst);
+            break;
         case GGML_OP_SUB:
            ggml_cuda_op_sub(ctx, dst);
            break;

@@ -3437,6 +3441,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
