Commit 71d8492

Merge branch 'master' into imatrix

2 parents 98bcd3e + 982e347

File tree: 23 files changed, +858 −162 lines


CMakePresets.json

Lines changed: 11 additions & 0 deletions

@@ -55,6 +55,17 @@
             "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake"
         }
     },
+    {
+        "name": "x64-linux-gcc", "hidden": true,
+        "cacheVariables": {
+            "CMAKE_C_COMPILER": "gcc",
+            "CMAKE_CXX_COMPILER": "g++"
+        }
+    },
+    { "name": "x64-linux-gcc-debug",          "inherits": [ "base", "x64-linux-gcc", "debug" ] },
+    { "name": "x64-linux-gcc-release",        "inherits": [ "base", "x64-linux-gcc", "release" ] },
+    { "name": "x64-linux-gcc-reldbg",         "inherits": [ "base", "x64-linux-gcc", "reldbg" ] },
+    { "name": "x64-linux-gcc+static-release", "inherits": [ "base", "x64-linux-gcc", "release", "static" ] },

    { "name": "arm64-windows-llvm-debug",   "inherits": [ "base", "arm64-windows-llvm", "debug" ] },
    { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] },

README.md

Lines changed: 1 addition & 0 deletions

@@ -133,6 +133,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
 - [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
 - [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
+- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)

 #### Multimodal

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 1 addition & 0 deletions

@@ -2090,6 +2090,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
        {
            // TODO: add support
            // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+#pragma message("TODO: implement F32, F16, BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
            return false;
        } break;
    case GGML_OP_CPY: {
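Note: `#pragma message` emits a compiler diagnostic at build time, so the unimplemented types and the tracking PR stay visible in every CANN build log while `ggml_backend_cann_supports_op` continues to return false for this op.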

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 15 additions & 0 deletions

@@ -43,6 +43,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/set-rows.cuh"
 #include "ggml.h"

 #include <algorithm>
@@ -2230,6 +2231,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_OP_GET_ROWS_BACK:
            ggml_cuda_op_get_rows_back(ctx, dst);
            break;
+       case GGML_OP_SET_ROWS:
+           ggml_cuda_op_set_rows(ctx, dst);
+           break;
        case GGML_OP_DUP:
            ggml_cuda_dup(ctx, dst);
            break;
@@ -2299,6 +2303,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
            case GGML_UNARY_OP_EXP:
                ggml_cuda_op_exp(ctx, dst);
                break;
+           case GGML_UNARY_OP_ELU:
+               ggml_cuda_op_elu(ctx, dst);
+               break;
            default:
                return false;
        }
@@ -3112,6 +3119,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
        case GGML_UNARY_OP_GELU_QUICK:
        case GGML_UNARY_OP_TANH:
        case GGML_UNARY_OP_EXP:
+       case GGML_UNARY_OP_ELU:
            return ggml_is_contiguous(op->src[0]);
        default:
            return false;
@@ -3216,6 +3224,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
        {
            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
        } break;
+       case GGML_OP_SET_ROWS:
+           {
+#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
+               return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) &&
+                      op->src[0]->type == GGML_TYPE_F32 &&
+                      op->src[1]->type == GGML_TYPE_I64;
+           } break;
        case GGML_OP_CPY:
        {
            ggml_type src0_type = op->src[0]->type;

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 145 additions & 0 deletions

@@ -0,0 +1,145 @@
+#include "set-rows.cuh"
+
+typedef void (*set_rows_kernel_t)(const char * src, char * dst);
+
+template<typename src_t, typename dst_t>
+__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {}
+
+template<>
+__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
+    *dst_h = __float2half(*src_f);
+}
+
+template<>
+__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
+    *dst_b = *src_f;
+}
+
+template<>
+__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
+    *dst_f = *src_f;
+}
+
+template<typename src_t, typename dst_t>
+static __global__ void k_set_rows(
+        const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t s10, const int64_t s11, const int64_t s12,
+        const int64_t s1, const int64_t s2, const int64_t s3) {
+
+    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
+
+    if (i >= ne_total) {
+        return;
+    }
+
+    const int64_t i03 = i / (ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
+    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
+    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
+
+    const int64_t i12 = i03 % ne12;
+    const int64_t i11 = i02 % ne11;
+    const int64_t i10 = i01;
+
+    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
+
+    const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
+    dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;
+
+    const src_t * src_elem = src0_row + i00;
+    dst_t * dst_elem = dst_row_ptr + i00;
+    set_rows_1(src_elem, dst_elem);
+}
+
+template<typename src_t, typename dst_t>
+static void set_rows_cuda(
+        const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const size_t nb01, const size_t nb02, const size_t nb03,
+        const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
+
+    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
+    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
+    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
+    const dim3 grid_size(num_blocks);
+
+    const int64_t s01 = nb01/sizeof(src_t);
+    const int64_t s02 = nb02/sizeof(src_t);
+    const int64_t s03 = nb03/sizeof(src_t);
+    const int64_t s10 = nb10/sizeof(int64_t);
+    const int64_t s11 = nb11/sizeof(int64_t);
+    const int64_t s12 = nb12/sizeof(int64_t);
+    const int64_t s1 = nb1/sizeof(dst_t);
+    const int64_t s2 = nb2/sizeof(dst_t);
+    const int64_t s3 = nb3/sizeof(dst_t);
+
+    if (ne_total > 0) {
+        k_set_rows<<<grid_size, block_size, 0, stream>>>(
+            src0_d, src1_d, dst_d,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            s01, s02, s03,
+            s10, s11, s12,
+            s1, s2, s3);
+    }
+}
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I64);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const float * src0_d = (const float *)src0->data;
+    const int64_t * src1_d = (const int64_t *)src1->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    if (dst->type == GGML_TYPE_F32) {
+        set_rows_cuda(
+            src0_d, src1_d, (float *)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else if (dst->type == GGML_TYPE_F16) {
+        set_rows_cuda(
+            src0_d, src1_d, (half *)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else if (dst->type == GGML_TYPE_BF16) {
+        set_rows_cuda(
+            src0_d, src1_d, (nv_bfloat16 *)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else {
+        GGML_ABORT("unsupported type");
+    }
+}
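For reference, here is a minimal CPU sketch of the semantics the kernel above implements, restricted to the F32 -> F32 case; the function name and signature are illustrative, not part of the patch. For each source row (i01, i02, i03), src1 supplies the destination row index, with src1's dims 1 and 2 broadcast over src0's dims 2 and 3 by modulo, exactly as the i11/i12 computation in k_set_rows:

#include <cstdint>

// Hypothetical CPU reference for k_set_rows (F32 -> F32 only).
// All strides are in elements, matching the s* values the launcher derives
// from the byte strides; ne10/ne13 are omitted since the kernel ignores them.
static void set_rows_ref(
        const float * src0, const int64_t * src1, float * dst,
        int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
        int64_t ne11, int64_t ne12,
        int64_t s01, int64_t s02, int64_t s03,   // src0 strides
        int64_t s10, int64_t s11, int64_t s12,   // src1 strides
        int64_t s1,  int64_t s2,  int64_t s3) {  // dst  strides
    for (int64_t i03 = 0; i03 < ne03; ++i03) {
        for (int64_t i02 = 0; i02 < ne02; ++i02) {
            for (int64_t i01 = 0; i01 < ne01; ++i01) {
                // row index to write, with modulo broadcast over dims 2/3
                const int64_t dst_row = src1[i01*s10 + (i02 % ne11)*s11 + (i03 % ne12)*s12];
                const float * src_row = src0 + i01*s01 + i02*s02 + i03*s03;
                float       * dst_ptr = dst  + dst_row*s1 + i02*s2 + i03*s3;
                for (int64_t i00 = 0; i00 < ne00; ++i00) {
                    dst_ptr[i00] = src_row[i00];  // the GPU version does one element per thread
                }
            }
        }
    }
}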

ggml/src/ggml-cuda/set-rows.cuh

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+#pragma once
+
+#include "common.cuh"
+
+#define CUDA_SET_ROWS_BLOCK_SIZE 256
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/unary.cu

Lines changed: 7 additions & 0 deletions

@@ -83,6 +83,10 @@ static __device__ __forceinline__ float op_log(float x) {
     return logf(x);
 }

+static __device__ __forceinline__ float op_elu(float x) {
+    return (x > 0.f) ? x : expm1f(x);
+}
+
 template <float (*op)(float), typename T>
 static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -196,6 +200,9 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_log>(ctx, dst);
 }

+void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_elu>(ctx, dst);
+}
 /* gated ops */

 template <float (*op)(float), typename T>
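The new op_elu is ELU with alpha = 1: identity for positive inputs, exp(x) - 1 otherwise. Using expm1f rather than expf(x) - 1.0f preserves precision for x near zero, where the subtraction would cancel. A CPU mirror of the device function, for illustration only:

#include <cmath>

// CPU equivalent of op_elu above: ELU with alpha = 1.
static float elu_ref(float x) {
    return x > 0.0f ? x : std::expm1f(x);  // expm1f(x) == exp(x) - 1, stable near 0
}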

ggml/src/ggml-cuda/unary.cuh

Lines changed: 2 additions & 0 deletions

@@ -59,6 +59,8 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

 void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

 void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 90 additions & 0 deletions

@@ -173,6 +173,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
     GGML_METAL_KERNEL_TYPE_ELU,
+    GGML_METAL_KERNEL_TYPE_ABS,
+    GGML_METAL_KERNEL_TYPE_SGN,
+    GGML_METAL_KERNEL_TYPE_STEP,
+    GGML_METAL_KERNEL_TYPE_HARDSWISH,
+    GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
+    GGML_METAL_KERNEL_TYPE_EXP,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -1155,6 +1161,12 @@ @implementation GGMLMetalClass
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
@@ -1688,6 +1700,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
        case GGML_UNARY_OP_SILU:
        case GGML_UNARY_OP_ELU:
        case GGML_UNARY_OP_NEG:
+       case GGML_UNARY_OP_ABS:
+       case GGML_UNARY_OP_SGN:
+       case GGML_UNARY_OP_STEP:
+       case GGML_UNARY_OP_HARDSWISH:
+       case GGML_UNARY_OP_HARDSIGMOID:
+       case GGML_UNARY_OP_EXP:
            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
        default:
            return false;
@@ -2439,6 +2457,78 @@ static bool ggml_metal_encode_node(

                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            } break;
+       case GGML_UNARY_OP_ABS:
+           {
+               id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
+
+               [encoder setComputePipelineState:pipeline];
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+               [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+               const int64_t n = ggml_nelements(dst);
+
+               [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+           } break;
+       case GGML_UNARY_OP_SGN:
+           {
+               id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
+
+               [encoder setComputePipelineState:pipeline];
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+               [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+               const int64_t n = ggml_nelements(dst);
+
+               [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+           } break;
+       case GGML_UNARY_OP_STEP:
+           {
+               id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
+
+               [encoder setComputePipelineState:pipeline];
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+               [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+               const int64_t n = ggml_nelements(dst);
+
+               [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+           } break;
+       case GGML_UNARY_OP_HARDSWISH:
+           {
+               id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
+
+               [encoder setComputePipelineState:pipeline];
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+               [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+               const int64_t n = ggml_nelements(dst);
+
+               [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+           } break;
+       case GGML_UNARY_OP_HARDSIGMOID:
+           {
+               id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
+
+               [encoder setComputePipelineState:pipeline];
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+               [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+               const int64_t n = ggml_nelements(dst);
+
+               [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+           } break;
+       case GGML_UNARY_OP_EXP:
+           {
+               id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
+
+               [encoder setComputePipelineState:pipeline];
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+               [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+               const int64_t n = ggml_nelements(dst);
+
+               [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+           } break;
        default:
            {
                GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
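Note: all six new encode cases reuse the element-wise pattern already used by ELU in this file: bind src0 and dst, then dispatch ggml_nelements(dst) threadgroups of a single thread each; ggml_metal_supports_op gates them on contiguous F32 input like the other unary kernels.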
