Skip to content

Commit 2b7be22

Browse files
authored
Merge branch 'ggerganov:master' into k-shift2
2 parents 5144fd9 + e597e50 commit 2b7be22

File tree

9 files changed

+352
-107
lines changed

9 files changed

+352
-107
lines changed

examples/server/server.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3261,7 +3261,7 @@ int main(int argc, char ** argv) {
32613261
ctx_server.queue_tasks.terminate();
32623262
};
32633263

3264-
LOG_INF("%s: server is listening on %s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
3264+
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
32653265

32663266
ctx_server.queue_tasks.start_loop();
32673267

ggml/src/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,7 @@ if (GGML_KOMPUTE)
800800
kompute-shaders/op_mul_mat_q8_0.comp
801801
kompute-shaders/op_mul_mat_q4_0.comp
802802
kompute-shaders/op_mul_mat_q4_1.comp
803+
kompute-shaders/op_mul_mat_q4_k.comp
803804
kompute-shaders/op_mul_mat_q6_k.comp
804805
kompute-shaders/op_getrows_f32.comp
805806
kompute-shaders/op_getrows_f16.comp
@@ -833,6 +834,7 @@ if (GGML_KOMPUTE)
833834
shaderop_mul_mat_q8_0.h
834835
shaderop_mul_mat_q4_0.h
835836
shaderop_mul_mat_q4_1.h
837+
shaderop_mul_mat_q4_k.h
836838
shaderop_mul_mat_q6_k.h
837839
shaderop_getrows_f32.h
838840
shaderop_getrows_f16.h
@@ -1400,7 +1402,7 @@ list(APPEND GGML_EXTRA_LIBS_PRIVATE Threads::Threads)
14001402

14011403
find_library(MATH_LIBRARY m)
14021404
if (MATH_LIBRARY)
1403-
if (NOT WIN32 OR NOT GGML_SYCL)
1405+
if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT})
14041406
list(APPEND GGML_EXTRA_LIBS_PRIVATE m)
14051407
endif()
14061408
endif()

ggml/src/ggml-backend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1508,7 +1508,7 @@ static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, co
15081508
return -1;
15091509
}
15101510

1511-
#if 1
1511+
#if 0
15121512
#define GGML_SCHED_MAX_SPLITS_DEBUG 4096
15131513
static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
15141514
#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)

ggml/src/ggml-cuda.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3107,18 +3107,20 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
31073107
}
31083108
return false;
31093109
} break;
3110+
case GGML_OP_NORM:
3111+
case GGML_OP_RMS_NORM:
3112+
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
3113+
break;
31103114
case GGML_OP_NONE:
31113115
case GGML_OP_RESHAPE:
31123116
case GGML_OP_VIEW:
31133117
case GGML_OP_PERMUTE:
31143118
case GGML_OP_TRANSPOSE:
3115-
case GGML_OP_NORM:
31163119
case GGML_OP_ADD:
31173120
case GGML_OP_ADD1:
31183121
case GGML_OP_SUB:
31193122
case GGML_OP_MUL:
31203123
case GGML_OP_DIV:
3121-
case GGML_OP_RMS_NORM:
31223124
case GGML_OP_SCALE:
31233125
case GGML_OP_SQR:
31243126
case GGML_OP_SQRT:

ggml/src/ggml-kompute.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "shaderop_mul_mat_q8_0.h"
2121
#include "shaderop_mul_mat_q4_0.h"
2222
#include "shaderop_mul_mat_q4_1.h"
23+
#include "shaderop_mul_mat_q4_k.h"
2324
#include "shaderop_mul_mat_q6_k.h"
2425
#include "shaderop_mul_mat_mat_f32.h"
2526
#include "shaderop_getrows_f32.h"
@@ -1067,6 +1068,40 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
10671068
ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
10681069
}
10691070

1071+
// Record a Kompute dispatch of the Q4_K mat-mul shader (op_mul_mat_q4_k.comp).
// inA holds block_q4_k data, inB the float activations, out the float result.
// r2/r3 are the broadcast ratios ne12/ne02 and ne13/ne03 supplied by the caller.
static void ggml_vk_mul_mat_q4_k(
    kp::Sequence& seq,
    const std::shared_ptr<kp::Tensor>& inA,
    const std::shared_ptr<kp::Tensor>& inB,
    const std::shared_ptr<kp::Tensor>& out,
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
    int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
    int32_t ne1, int32_t r2, int32_t r3
) {
    // SPIR-V blob is embedded at build time; decoded once and cached for the process.
    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
        kp::shader_data::op_mul_mat_q4_k_comp_spv_len);

    // Must mirror the push_constant block declared in op_mul_mat_q4_k.comp.
    struct PushConstants {
        uint32_t inAOff, inBOff, outOff;
        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
    } pushConsts {
        // NOTE(review): the inAOff/inBOff/outOff parameters are accepted but the
        // push constants are hard-coded to 0 — confirm the offsets are always
        // zero for these tensors (pattern copied from the sibling mul_mat paths).
        0, 0, 0,
        ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
    };

    // The shader computes 4 output rows per workgroup, hence (ne01 + 3)/4 groups in x.
    std::shared_ptr<kp::Algorithm> algo;
    if (komputeManager()->hasAlgorithm(__func__)) {
        // Reuse the cached pipeline; rebind tensors, workgroup and constants for this call.
        algo = komputeManager()->getAlgorithm(__func__);
        algo->setTensors({inA, inB, out});
        algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
        algo->setPushConstants<PushConstants>({pushConsts});
        algo->updateDescriptors(s_kompute_context->pool.get());
    } else {
        // First use: build the algorithm under this function's name as the cache key.
        algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
    }
    seq.record<kp::OpAlgoDispatch>(algo);
}
1104+
10701105
static void ggml_vk_mul_mat_q6_k(
10711106
kp::Sequence& seq,
10721107
const std::shared_ptr<kp::Tensor>& inA,
@@ -1384,6 +1419,7 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
13841419
case GGML_TYPE_Q8_0:
13851420
case GGML_TYPE_Q4_0:
13861421
case GGML_TYPE_Q4_1:
1422+
case GGML_TYPE_Q4_K:
13871423
return true;
13881424
default:
13891425
;
@@ -1635,6 +1671,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
16351671
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
16361672
);
16371673
break;
1674+
case GGML_TYPE_Q4_K:
1675+
ggml_vk_mul_mat_q4_k(
1676+
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1677+
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
1678+
);
1679+
break;
16381680
case GGML_TYPE_Q6_K:
16391681
ggml_vk_mul_mat_q6_k(
16401682
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,

ggml/src/ggml.c

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7272,6 +7272,7 @@ struct ggml_tensor * ggml_ssm_conv(
72727272
const int64_t n_s = sx->ne[2];
72737273

72747274
// TODO: maybe support other strides than 1?
7275+
// FIXME: this is always true?
72757276
GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
72767277
GGML_ASSERT(sx->ne[1] == d_inner);
72777278
GGML_ASSERT(n_t >= 0);
@@ -22102,18 +22103,46 @@ static size_t gguf_type_size(enum gguf_type type) {
2210222103
return GGUF_TYPE_SIZE[type];
2210322104
}
2210422105

22105-
static void gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
22106-
GGML_ASSERT(info->n_dims <= GGML_MAX_DIMS);
22107-
GGML_ASSERT(0 <= info->type && info->type < GGML_TYPE_COUNT);
22106+
static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
22107+
if (info->n_dims > GGML_MAX_DIMS) {
22108+
fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
22109+
return false;
22110+
}
22111+
22112+
if (info->type < 0 || info->type >= GGML_TYPE_COUNT) {
22113+
fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
22114+
return false;
22115+
}
22116+
22117+
if (strlen(info->name.data) >= GGML_MAX_NAME) {
22118+
fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
22119+
return false;
22120+
}
2210822121

2210922122
for (uint32_t i = 0; i < info->n_dims; ++i) {
22110-
GGML_ASSERT(info->ne[i] > 0);
22123+
if (info->ne[i] <= 0) {
22124+
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
22125+
return false;
22126+
}
2211122127
}
2211222128

2211322129
// prevent overflow for total number of elements
22114-
GGML_ASSERT(INT64_MAX/info->ne[1] > info->ne[0]);
22115-
GGML_ASSERT(INT64_MAX/info->ne[2] > info->ne[0]*info->ne[1]);
22116-
GGML_ASSERT(INT64_MAX/info->ne[3] > info->ne[0]*info->ne[1]*info->ne[2]);
22130+
if (INT64_MAX/info->ne[1] <= info->ne[0]) {
22131+
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
22132+
return false;
22133+
}
22134+
22135+
if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
22136+
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
22137+
return false;
22138+
}
22139+
22140+
if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
22141+
fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
22142+
return false;
22143+
}
22144+
22145+
return true;
2211722146
}
2211822147

2211922148
static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
@@ -22414,8 +22443,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
2241422443
ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
2241522444
ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
2241622445

22417-
// TODO: return an error instead of crashing with GGML_ASSERT
22418-
gguf_tensor_info_sanitize(info);
22446+
ok = ok && gguf_tensor_info_sanitize(info);
2241922447

2242022448
// make sure there is no duplicated tensor names
2242122449
for (uint64_t j = 0; j < i && ok; ++j) {

ggml/src/kompute-shaders/common.comp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define TWOPI_F 6.283185307179586f
1616

1717
#define QK_K 256
18+
#define K_SCALE_SIZE 12
1819

1920
#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
2021
#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
@@ -64,6 +65,14 @@ mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
6465
return reg;
6566
}
6667

68+
// Q4_K super-block: byte size = 2 (d) + 2 (dmin) + 12 (scales) + QK_K/2 = 144.
#define sizeof_block_q4_k 144
struct block_q4_k {
    float16_t d;                  // super-block scale applied to the packed sub-block scales
    float16_t dmin;               // super-block scale applied to the packed sub-block mins
    uint8_t scales[K_SCALE_SIZE]; // per-sub-block scales and mins, packed 6 bits each
    uint8_t qs[QK_K/2];           // 4-bit quants, two per byte (QK_K = 256 weights per block)
};
75+
6776
#define sizeof_block_q6_k 210
6877
struct block_q6_k {
6978
uint8_t ql[QK_K/2]; // quants, lower 4 bits
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
#version 450

// Mat-mul kernel for Q4_K-quantized weights (inA) against float activations
// (inB). Each workgroup produces N_DST consecutive output rows; the 4x8 local
// size gives 32 invocations which the code addresses through subgroup lane IDs
// (assumes a subgroup of at least 32 invocations — TODO confirm on target HW).

#include "common.comp"

#define N_DST 4
#define SIZE_OF_BLOCK sizeof_block_q4_k

layout(local_size_x = 4) in;
layout(local_size_y = 8) in;
layout(local_size_z = 1) in;

layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };

// Mirrors the PushConstants struct built in ggml_vk_mul_mat_q4_k (ggml-kompute.cpp).
layout (push_constant) uniform parameter {
    uint inAOff;
    uint inBOff;
    uint outOff;
    int ne00;
    int ne10;
    int ne0;
    int ne1;
    int ne01;
    int ne02;
    int ne12;
    int r2;
    int r3;
} pcs;

void main() {
    // bit masks used to extract the 6-bit sub-block scales/mins from the
    // packed 12-byte 'scales' field (K-quant scale packing)
    const uint16_t kmask1 = uint16_t(0x3f3f);
    const uint16_t kmask2 = uint16_t(0x0f0f);
    const uint16_t kmask3 = uint16_t(0xc0c0);

    // split the 32 subgroup lanes: 4 block-strides x (2 halves x 4 sub-offsets)
    const uint ix = gl_SubgroupInvocationID/8; // 0...3
    const uint it = gl_SubgroupInvocationID%8; // 0...7
    const uint iq = it/4; // 0 or 1
    const uint ir = it%4; // 0...3

    // number of Q4_K super-blocks per weight row (QK_K = 256 from common.comp)
    const uint nb = pcs.ne00/QK_K;

    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint im = gl_WorkGroupID.z;

    // this workgroup covers output rows [first_row, first_row + N_DST)
    const uint first_row = r0 * N_DST;
    const uint ib_row = first_row * nb;

    // decompose the z workgroup index into batch dims; r2/r3 broadcast src0
    const uint i12 = im%pcs.ne12;
    const uint i13 = im/pcs.ne12;

    const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);

    const uint xblk = ib_row + offset0 + pcs.inAOff;
    const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;

    float yl[16];                                // activations for the low 128 weights of the block half
    float yh[16];                                // activations for the high 128 weights
    float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};    // per-row dot-product accumulators
    float all_sum = 0.f;

    // starting activation index for this lane within its first block
    uint y4 = y + ix * QK_K + 64 * iq + 8 * ir;

    // lanes with the same 'it' cooperate on different blocks, striding by 4
    for (uint ib = ix; ib < nb; ib += 4) {
        const uint blk_idx = ib + xblk;

        // load 4 groups of 8 activations and their partial sums (sums are
        // needed for the 'dmin' correction term below)
        float sumy[4] = {0.f, 0.f, 0.f, 0.f};
        for (int i = 0; i < 8; ++i) {
            yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0];
            yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8];
            yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0];
            yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8];
        }

        for (int row = 0; row < N_DST; row++) {
            uint row_idx = row * nb;

            // read the packed scale bytes as five 16-bit words
            uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
            uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
            uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4);
            uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6);
            uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8);

            // reassemble the four 6-bit scale/min pairs for this block half
            uint16_t sc16[4];
            sc16[0] = sc_0 & kmask1;
            sc16[1] = sc_2 & kmask1;
            sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2);
            sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2);

            // accumulate quant*activation products; nibbles are masked in
            // place and the position factors (1/16, 1/256) are applied later
            float acc1[4] = {0.f, 0.f, 0.f, 0.f};
            float acc2[4] = {0.f, 0.f, 0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
                uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i);
                uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i);
                acc1[0] += yl[i+0] * (q1 & 0x000F);
                acc1[1] += yl[i+1] * (q1 & 0x0F00);
                acc1[2] += yl[i+8] * (q1 & 0x00F0);
                acc1[3] += yl[i+9] * (q1 & 0xF000);
                acc2[0] += yh[i+0] * (q2 & 0x000F);
                acc2[1] += yh[i+1] * (q2 & 0x0F00);
                acc2[2] += yh[i+8] * (q2 & 0x00F0);
                acc2[3] += yh[i+9] * (q2 & 0xF000);
            }

            // split each 16-bit word into its two 8-bit scale (even) / min (odd) values
            uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF);
            uint8_t sc8_1 = uint8_t(sc16[0] >> 8 );
            uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF);
            uint8_t sc8_3 = uint8_t(sc16[1] >> 8 );
            uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF);
            uint8_t sc8_5 = uint8_t(sc16[2] >> 8 );
            uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF);
            uint8_t sc8_7 = uint8_t(sc16[3] >> 8 );

            // dequantize: d * (scaled products) - dmin * (scaled activation sums);
            // 1/256 and 1/16 undo the nibble positions masked above
            float dall = float(inA[blk_idx + row_idx].d);
            float dmin = float(inA[blk_idx + row_idx].dmin);
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 +
                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f +
                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 +
                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) -
                         dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7);
        }

        // advance to this lane's slice of the next 4-block stride
        y4 += 4 * QK_K;
    }

    // reduce partial sums across the subgroup; one lane writes each output row
    for (int row = 0; row < N_DST; ++row) {
        all_sum = subgroupAdd(sumf[row]);
        if (subgroupElect()) {
            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum;
        }
    }
}

0 commit comments

Comments (0)