
Commit 09263e0

Merge branch 'ggml-org:master' into mradermacher

2 parents 3092176 + 77d5e9a

17 files changed: +256 -160 lines changed

common/json-schema-to-grammar.cpp

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,9 @@ using json = nlohmann::ordered_json;
 static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
     auto has_max = max_items != std::numeric_limits<int>::max();
 
+    if (max_items == 0) {
+        return "";
+    }
     if (min_items == 0 && max_items == 1) {
         return item_rule + "?";
     }
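The new max_items == 0 guard makes an explicit zero-repetition request collapse to an empty rule instead of emitting a degenerate "{0,0}" range; the same guard is added to the Python and JavaScript generators below. A minimal standalone sketch of this repetition logic, with simplified rule strings assumed for illustration (not the library code itself):

// sketch of build_repetition-style output, including the new max_items == 0 early return
#include <climits>
#include <cstdio>
#include <string>

static std::string repetition(const std::string & item, int min_items, int max_items) {
    if (max_items == 0) {
        return "";                          // new guard: zero repetitions -> empty rule
    }
    if (min_items == 0 && max_items == 1) {
        return item + "?";
    }
    if (min_items == 1 && max_items == INT_MAX) {
        return item + "+";
    }
    return item + "{" + std::to_string(min_items) + "," +
           (max_items == INT_MAX ? "" : std::to_string(max_items)) + "}";
}

int main() {
    printf("[%s]\n", repetition("item", 0, 0).c_str()); // []
    printf("[%s]\n", repetition("item", 0, 1).c_str()); // [item?]
    printf("[%s]\n", repetition("item", 1, 3).c_str()); // [item{1,3}]
    return 0;
}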

examples/json_schema_to_grammar.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,9 @@
 
 def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
 
+    if max_items == 0:
+        return ""
+
     if min_items == 0 and max_items == 1:
         return f'{item_rule}?'
 

examples/llava/clip.cpp

Lines changed: 46 additions & 22 deletions
@@ -554,15 +554,15 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
 }
 
 // implementation of the 2D RoPE without adding a new op in ggml
+// this is not efficient (use double the memory), but works on all backends
+// TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
 static ggml_tensor * build_rope_2d(
-    ggml_cgraph * gf,
     ggml_context * ctx0,
     ggml_tensor * cur,
     ggml_tensor * pos_h,
     ggml_tensor * pos_w,
     const float freq_base
 ) {
-    ggml_tensor * tmp;
     const int64_t n_dim  = cur->ne[0];
     const int64_t n_head = cur->ne[1];
     const int64_t n_pos  = cur->ne[2];
@@ -571,18 +571,23 @@ static ggml_tensor * build_rope_2d(
     // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
     // first half of cur will use 1e-0, 1e-2 (even)
     // second half of cur will use 1e-1, 1e-3 (odd)
-    //
-    // for the first half, the trick here is to rotate n_dim/2, so inv_freq will be even
+    // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even
     //   ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
     // then for the second half, we use freq_scale to shift the inv_freq
     //   ^ why? replace (2i) with (2i+1) in the above equation
     const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
 
     // first half
+    ggml_tensor * first;
     {
-        cur = ggml_rope_ext_inplace(
+        first = ggml_view_3d(ctx0, cur,
+            n_dim/2, n_head, n_pos,
+            ggml_row_size(cur->type, n_dim),
+            ggml_row_size(cur->type, n_dim*n_head),
+            0);
+        first = ggml_rope_ext(
             ctx0,
-            cur,
+            first,
             pos_h,      // positions
             nullptr,    // freq factors
             n_dim/2,    // n_dims
@@ -592,26 +597,27 @@ static ggml_tensor * build_rope_2d(
     }
 
     // second half
+    ggml_tensor * second;
     {
-        tmp = ggml_view_3d(ctx0, cur,
+        second = ggml_view_3d(ctx0, cur,
             n_dim/2, n_head, n_pos,
             ggml_row_size(cur->type, n_dim),
             ggml_row_size(cur->type, n_dim*n_head),
             n_dim/2 * ggml_element_size(cur));
-        tmp = ggml_rope_ext_inplace(
+        second = ggml_cont(ctx0, second); // copy, because ggml_rope don't play well with non-contiguous tensors
+        second = ggml_rope_ext(
             ctx0,
-            tmp,
+            second,
             pos_w,      // positions
             nullptr,    // freq factors
            n_dim/2,    // n_dims
            0, 0, freq_base,
            freq_scale_odd,
            0.0f, 1.0f, 0.0f, 0.0f
        );
-        // calculate inplace (modify cur directly)
-        ggml_build_forward_expand(gf, tmp);
     }
 
+    cur = ggml_concat(ctx0, first, second, 0);
     return cur;
 }
 
@@ -680,13 +686,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
 
     Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
-    Q = build_rope_2d(gf, ctx0, Q, pos_h, pos_w, hparams.rope_theta);
+    Q = build_rope_2d(ctx0, Q, pos_h, pos_w, hparams.rope_theta);
     Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
 
     struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
 
     K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
-    K = build_rope_2d(gf, ctx0, K, pos_h, pos_w, hparams.rope_theta);
+    K = build_rope_2d(ctx0, K, pos_h, pos_w, hparams.rope_theta);
     K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
 
     struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
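The frequency comments in build_rope_2d can be checked numerically. A standalone sketch (not part of the commit), assuming n_dim = 8 and freq_base = 10000 to match the "1e-0 ... 1e-3" comment:

// numeric check of the inv_freq trick: rotating over n_dim/2 dims yields the even
// inv_freqs, and scaling by freq_scale_odd = base^(-2/n_dim) shifts them to the odd ones
#include <cmath>
#include <cstdio>

int main() {
    const int   n_dim     = 8;        // assumed for illustration
    const float freq_base = 10000.0f; // assumed for illustration
    const float freq_scale_odd = std::pow(freq_base, -2.0f / n_dim);
    for (int i = 0; i < n_dim / 4; ++i) {
        float even = std::pow(freq_base, -2.0f * i / (n_dim / 2)); // base^(-2(2i)/n_dim)
        float odd  = freq_scale_odd * even;                        // base^(-2(2i+1)/n_dim)
        printf("i=%d  even=%g  odd=%g\n", i, even, odd);
    }
    // prints: i=0 even=1 odd=0.1 ; i=1 even=0.01 odd=0.001
    return 0;
}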
@@ -2796,10 +2802,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;
 
+    // TODO @ngxson : this is ugly, need to refactor later
+    bool support_dynamic_size = ctx->has_minicpmv_projector
+        || ctx->has_qwen2vl_merger
+        || ctx->proj_type == PROJECTOR_TYPE_PIXTRAL;
+
     const int image_size = hparams.image_size;
     int image_size_width  = image_size;
     int image_size_height = image_size;
-    if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) {
+    if (support_dynamic_size) {
         image_size_width  = imgs.entries[0]->nx;
         image_size_height = imgs.entries[0]->ny;
     }
@@ -2811,9 +2822,20 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 
     {
         struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
-        float * data = (float *)malloc(ggml_nbytes(inp_raw));
+        std::vector<float> inp_data(ggml_nelements(inp_raw));
+        float * data = inp_data.data();
+
+        // layout of data (note: the channel dim is unrolled to better visualize the layout):
+        //
+        // ┌──W──┐
+        // │     │ H   channel = R
+        // ├─────┤     │
+        // │     │ H   channel = G
+        // ├─────┤     │
+        // │     │ H   channel = B
+        // └─────┘     │
+        //   ──────────┘ x B
 
-        // TODO @ngxson : this whole code block is ugly, will need to be refactored
         for (size_t i = 0; i < imgs.entries.size(); i++) {
             const int nx = imgs.entries[i]->nx;
             const int ny = imgs.entries[i]->ny;
@@ -2828,17 +2850,19 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             const int n = nx * ny;
 
             for (int b = 0; b < batch_size; b++) {
-                for (int k = 0; k < 3; k++) {
-                    for (int y = 0; y < ny; y++) {
-                        for (int x = 0; x < nx; x++) {
-                            data[(b * 3 * n) + k * n + y * nx + x] = imgs.entries[b]->buf[3 * (y * nx + x) + k];
-                        }
+                float * batch_entry = data + b * (3*n);
+                for (int y = 0; y < ny; y++) {
+                    for (int x = 0; x < nx; x++) {
+                        size_t base_src = 3*(y * nx + x); // idx of the first channel
+                        size_t base_dst =    y * nx + x;  // idx of the first channel
+                        batch_entry[      base_dst] = imgs.entries[b]->buf[base_src    ];
+                        batch_entry[1*n + base_dst] = imgs.entries[b]->buf[base_src + 1];
+                        batch_entry[2*n + base_dst] = imgs.entries[b]->buf[base_src + 2];
                     }
                 }
             }
         }
         ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
-        free(data);
     }
     if (ctx->has_minicpmv_projector) {
         {
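The repacking loop above converts the interleaved RGBRGB... pixel buffer into the planar per-channel layout drawn in the comment. A tiny standalone illustration of the same indexing, using a made-up 2x1 image (not the clip.cpp code itself):

// interleaved -> planar repacking on a toy image: dst holds the R plane, then G, then B
#include <cstdio>
#include <vector>

int main() {
    const int nx = 2, ny = 1, n = nx * ny;
    // interleaved source: pixel0 = (10,20,30), pixel1 = (11,21,31)
    std::vector<float> src = { 10, 20, 30,   11, 21, 31 };
    std::vector<float> dst(3 * n);
    for (int y = 0; y < ny; y++) {
        for (int x = 0; x < nx; x++) {
            int base_src = 3 * (y * nx + x);
            int base_dst = y * nx + x;
            dst[        base_dst] = src[base_src    ]; // R plane
            dst[1*n +   base_dst] = src[base_src + 1]; // G plane
            dst[2*n +   base_dst] = src[base_src + 2]; // B plane
        }
    }
    for (float v : dst) printf("%g ", v); // prints: 10 11 20 21 30 31
    printf("\n");
    return 0;
}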

examples/server/public_legacy/json-schema-to-grammar.mjs

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,9 @@
 const SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}';
 
 function _buildRepetition(itemRule, minItems, maxItems, opts={}) {
+  if (maxItems == 0) {
+    return '';
+  }
   if (minItems === 0 && maxItems === 1) {
     return `${itemRule}?`;
   }

ggml/include/ggml-cpu.h

Lines changed: 5 additions & 0 deletions
@@ -133,6 +133,11 @@ extern "C" {
 
     GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
 
+    GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
+    GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
+    GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
+    GGML_BACKEND_API void ggml_cpu_bf16_to_fp32(const ggml_bf16_t *, float *, int64_t);
+
 #ifdef __cplusplus
 }
 #endif
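A hedged usage sketch for the newly exported conversion helpers; it assumes the ggml headers "ggml.h" and "ggml-cpu.h" are on the include path and the CPU backend is linked, and is not taken from the commit:

// round-trip a few fp32 values (exactly representable in fp16) through the new CPU helpers
#include "ggml.h"
#include "ggml-cpu.h"
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float>       src  = { 0.5f, -1.25f, 3.0f };
    std::vector<ggml_fp16_t> half(src.size());
    std::vector<float>       back(src.size());

    ggml_cpu_fp32_to_fp16(src.data(), half.data(), (int64_t) src.size());
    ggml_cpu_fp16_to_fp32(half.data(), back.data(), (int64_t) back.size());

    for (size_t i = 0; i < src.size(); ++i) {
        printf("%f -> %f\n", src[i], back[i]);
    }
    return 0;
}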

ggml/include/ggml-rpc.h

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 extern "C" {
 #endif
 
-#define RPC_PROTO_MAJOR_VERSION    1
+#define RPC_PROTO_MAJOR_VERSION    2
 #define RPC_PROTO_MINOR_VERSION    0
 #define RPC_PROTO_PATCH_VERSION    0
 #define GGML_RPC_MAX_SERVERS       16

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 89 additions & 2 deletions
@@ -215,7 +215,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .nrows = 1,
     },
     [GGML_TYPE_F16] = {
-        .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+        .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_fp16,
         .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
         .vec_dot_type = GGML_TYPE_F16,
         .nrows = 1,
@@ -356,7 +356,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .from_float = quantize_row_q8_K,
     },
     [GGML_TYPE_BF16] = {
-        .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
+        .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_bf16,
         .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
         .vec_dot_type = GGML_TYPE_BF16,
         .nrows = 1,
@@ -3166,6 +3166,93 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
     return ggml_graph_compute(cgraph, &cplan);
 }
 
+void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
+    int64_t i = 0;
+#if defined(__F16C__)
+#if defined(__AVX512F__)
+    for (; i + 15 < n; i += 16) {
+        __m512 x_vec = _mm512_loadu_ps(x + i);
+        __m256i y_vec = _mm512_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+        _mm256_storeu_si256((__m256i *)(y + i), y_vec);
+    }
+#endif
+    for (; i + 7 < n; i += 8) {
+        __m256 x_vec = _mm256_loadu_ps(x + i);
+        __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+        _mm_storeu_si128((__m128i *)(y + i), y_vec);
+    }
+    for (; i + 3 < n; i += 4) {
+        __m128 x_vec = _mm_loadu_ps(x + i);
+        __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+        _mm_storel_epi64((__m128i *)(y + i), y_vec);
+    }
+#endif
+    for (; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(x[i]);
+    }
+}
+
+void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {
+    int64_t i = 0;
+#if defined(__F16C__)
+#if defined(__AVX512F__)
+    for (; i + 15 < n; i += 16) {
+        __m256i x_vec = _mm256_loadu_si256((const __m256i *)(x + i));
+        __m512 y_vec = _mm512_cvtph_ps(x_vec);
+        _mm512_storeu_ps(y + i, y_vec);
+    }
+#endif
+    for (; i + 7 < n; i += 8) {
+        __m128i x_vec = _mm_loadu_si128((const __m128i *)(x + i));
+        __m256 y_vec = _mm256_cvtph_ps(x_vec);
+        _mm256_storeu_ps(y + i, y_vec);
+    }
+    for (; i + 3 < n; i += 4) {
+        __m128i x_vec = _mm_loadl_epi64((const __m128i *)(x + i));
+        __m128 y_vec = _mm_cvtph_ps(x_vec);
+        _mm_storeu_ps(y + i, y_vec);
+    }
+#endif
+    for (; i < n; ++i) {
+        y[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+}
+
+void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) {
+    int64_t i = 0;
+    for (; i < n; ++i) {
+        y[i] = GGML_FP32_TO_BF16(x[i]);
+    }
+}
+
+void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {
+    int64_t i = 0;
+#if defined(__AVX2__)
+#if defined(__AVX512F__)
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(y + i,
+                         _mm512_castsi512_ps(
+                             _mm512_slli_epi32(
+                                 _mm512_cvtepu16_epi32(
+                                     _mm256_loadu_si256(
+                                         (const __m256i *)(x + i))),
+                                 16)));
+    }
+#endif
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(y + i,
+                         _mm256_castsi256_ps(
+                             _mm256_slli_epi32(
+                                 _mm256_cvtepu16_epi32(
+                                     _mm_loadu_si128(
+                                         (const __m128i *)(x + i))),
+                                 16)));
+    }
+#endif
+    for (; i < n; i++) {
+        y[i] = GGML_BF16_TO_FP32(x[i]);
+    }
+}
 
 int ggml_cpu_has_avx(void) {
 #if defined(__AVX__)
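The bf16 -> fp32 path above only needs a 16-bit left shift because bf16 is the upper half of an IEEE-754 binary32 value. A standalone scalar sketch of the same idea (illustrative only, not the ggml implementation; real converters also handle rounding and NaN):

// bf16 <-> fp32 by bit manipulation: widening restores the low mantissa bits as zeros
#include <cstdint>
#include <cstdio>
#include <cstring>

static float bf16_bits_to_fp32(uint16_t h) {
    uint32_t bits = (uint32_t) h << 16;   // same shift the AVX path performs in bulk
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

static uint16_t fp32_to_bf16_bits(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));
    return (uint16_t) (bits >> 16);       // truncation for simplicity
}

int main() {
    float x = 3.140625f;                  // exactly representable in bf16
    uint16_t h = fp32_to_bf16_bits(x);
    printf("%f -> 0x%04x -> %f\n", x, h, bf16_bits_to_fp32(h));
    return 0;
}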

ggml/src/ggml-cpu/ops.cpp

Lines changed: 2 additions & 2 deletions
@@ -4222,7 +4222,7 @@ static void ggml_compute_forward_get_rows_f16(
 
         GGML_ASSERT(i01 >= 0 && i01 < ne01);
 
-        ggml_fp16_to_fp32_row(
+        ggml_cpu_fp16_to_fp32(
                 (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
                 (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
     }
@@ -4263,7 +4263,7 @@ static void ggml_compute_forward_get_rows_bf16(
 
         GGML_ASSERT(i01 >= 0 && i01 < ne01);
 
-        ggml_bf16_to_fp32_row(
+        ggml_cpu_bf16_to_fp32(
                 (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
                 (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
     }

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 12 additions & 6 deletions
@@ -378,8 +378,8 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
 }
 
 // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
-// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
-static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
+// No response
+static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
     uint8_t cmd_byte = cmd;
     if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
         return false;
@@ -390,6 +390,15 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
     if (!send_data(sock->fd, input, input_size)) {
         return false;
     }
+    return true;
+}
+
+// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
+// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
+static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
+    if (!send_rpc_cmd(sock, cmd, input, input_size)) {
+        return false;
+    }
     // TODO: currently the output_size is always known, do we need support for commands with variable output size?
     // even if we do, we can skip sending output_size from the server for commands with known output size
     uint64_t out_size;
@@ -555,7 +564,7 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
     memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
     GGML_ASSERT(status);
 }
 
@@ -1428,9 +1437,6 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
                 if (!server.set_tensor(input)) {
                     return;
                 }
-                if (!send_msg(sockfd, nullptr, 0)) {
-                    return;
-                }
                 break;
             }
             case RPC_CMD_SET_TENSOR_HASH: {
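A hedged sketch of the request framing described in the send_rpc_cmd comments, | rpc_cmd (1 byte) | request_size (8 bytes) | request_data |; the command value and payload are made up, and the byte layout assumes a little-endian host. This is an illustration of the framing only, not the ggml-rpc code:

// pack a single RPC request into one contiguous byte buffer
#include <cstdint>
#include <cstring>
#include <vector>

static std::vector<uint8_t> frame_request(uint8_t cmd, const void * payload, uint64_t payload_size) {
    std::vector<uint8_t> msg(1 + sizeof(uint64_t) + payload_size);
    msg[0] = cmd;                                                  // rpc_cmd (1 byte)
    memcpy(msg.data() + 1, &payload_size, sizeof(payload_size));   // request_size (8 bytes)
    memcpy(msg.data() + 9, payload, payload_size);                 // request_data
    return msg;
}

int main() {
    const char data[] = "tensor-bytes";                            // hypothetical payload
    auto msg = frame_request(/*cmd=*/0x01, data, sizeof(data));    // 0x01 is a placeholder command id
    return msg.size() == 1 + 8 + sizeof(data) ? 0 : 1;
}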

ggml/src/ggml-sycl/common.hpp

Lines changed: 0 additions & 1 deletion
@@ -313,7 +313,6 @@ struct ggml_backend_sycl_context {
     int device;
     std::string name;
     optimize_feature opt_feature;
-    bool optimized_graph=false;
 
     queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } };
 