
Commit 75ebba1

Merge branch 'ggerganov:master' into master2
2 parents 9b8fcda + 2194200

9 files changed: +93, -39 lines

common/json-schema-to-grammar.cpp

Lines changed: 1 addition & 1 deletion

@@ -611,7 +611,7 @@ class SchemaConverter {
            }
            return join_seq();
        };
-        return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
+        return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
    }

    /*
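
The same one-character change is mirrored in the Python and JavaScript converters below, and the new test case near the end of this commit ("regexp with top-level alternation") shows why it matters: wrapping the pattern-derived body in parentheses keeps a top-level alternation in the regex scoped between the surrounding quote literals. Taking the added test expectation as the "after" form, and reconstructing the "before" form by hand (my reconstruction, not part of the commit), a schema pattern of ^A|B|C|D$ now produces

    root ::= "\"" ("A" | "B" | "C" | "D") "\"" space

whereas the unparenthesized body would have let the alternation escape the quotes, roughly:

    root ::= "\"" "A" | "B" | "C" | "D" "\"" space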

examples/json_schema_to_grammar.py

Lines changed: 1 addition & 1 deletion

@@ -540,7 +540,7 @@ def join_seq():
        return self._add_rule(
            name,
            to_rule(transform()) if self._raw_pattern \
-                else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
+                else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")


    def _resolve_ref(self, ref):

examples/llava/llava.cpp

Lines changed: 1 addition & 1 deletion

@@ -432,7 +432,7 @@ struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * c
    bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos);
    if (!image_embed_result) {
        clip_image_u8_free(img);
-        LOG_ERR("%s: coulnd't embed the image\n", __func__);
+        LOG_ERR("%s: couldn't embed the image\n", __func__);
        return NULL;
    }


examples/server/public/json-schema-to-grammar.mjs

Lines changed: 1 addition & 1 deletion

@@ -529,7 +529,7 @@ export class SchemaConverter {
        return joinSeq();
      };

-      return this._addRule(name, "\"\\\"\" " + toRule(transform()) + " \"\\\"\" space")
+      return this._addRule(name, "\"\\\"\" (" + toRule(transform()) + ") \"\\\"\" space")
  }

  _notStrings(strings) {

ggml/src/ggml-backend.cpp

Lines changed: 10 additions & 7 deletions

@@ -682,8 +682,6 @@ ggml_backend_t ggml_backend_init_best(void) {

 // backend CPU

-static const size_t TENSOR_ALIGNMENT = 32; // required for mmap as gguf only guarantees 32-byte alignment
-
 static const char * ggml_backend_cpu_buffer_get_name(ggml_backend_buffer_t buffer) {
     return "CPU";


@@ -702,7 +700,7 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
 }

 static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    free(buffer->context);
+    ggml_aligned_free(buffer->context, buffer->size);
 }

 static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {

@@ -770,14 +768,19 @@ static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_ty
 }

 static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
-    void * data = malloc(size); // TODO: use GGML_ALIGNED_MALLOC (move to ggml-impl.h)
+    auto alloc_size = size;
+    if (alloc_size == 0) {
+        alloc_size = 1;
+    }
+
+    void * data = ggml_aligned_malloc(alloc_size);
+
     if (data == NULL) {
-        GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, size);
+        GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, alloc_size);
         return NULL;
     }

-    return ggml_backend_buffer_init(buft, ggml_backend_cpu_buffer_i, data, size);
+    return ggml_backend_buffer_init(buft, ggml_backend_cpu_buffer_i, data, alloc_size);
 }

 static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
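
A minimal usage sketch of the contract the CPU backend now follows (illustrative only; the helper below is mine, but the two declarations come from the ggml-impl.h hunk further down): buffers obtained from ggml_aligned_malloc must be released with ggml_aligned_free together with the size that was allocated, and zero-byte requests are bumped to one byte, as in the hunk above.

#include <stddef.h>
#include "ggml-impl.h"   // declares ggml_aligned_malloc() and ggml_aligned_free()

// Illustrative helper, not part of the commit: allocate, use and release an
// aligned buffer. The size passed to ggml_aligned_free() must match the
// allocation, because the macOS path frees with vm_deallocate(), which takes
// an explicit region length.
static void aligned_roundtrip(size_t size) {
    const size_t alloc_size = size > 0 ? size : 1;  // never request 0 bytes
    void * data = ggml_aligned_malloc(alloc_size);
    if (data == NULL) {
        return;                                     // allocation failed
    }
    // ... fill or read the buffer here ...
    ggml_aligned_free(data, alloc_size);            // same size as allocated
}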

ggml/src/ggml-impl.h

Lines changed: 8 additions & 0 deletions

@@ -19,6 +19,9 @@ extern "C" {
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))

+// required for mmap as gguf only guarantees 32-byte alignment
+#define TENSOR_ALIGNMENT 32
+
 // static_assert should be a #define, but if it's not,
 // fall back to the _Static_assert C11 keyword.
 // if C99 - static_assert is noop

@@ -196,6 +199,11 @@ struct ggml_cgraph {

 struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1);

+// Memory allocation
+
+void * ggml_aligned_malloc(size_t size);
+void ggml_aligned_free(void * ptr, size_t size);
+
 #ifdef __cplusplus
 }
 #endif

ggml/src/ggml.c

Lines changed: 52 additions & 22 deletions

@@ -35,10 +35,6 @@
 #include <omp.h>
 #endif

-#ifdef GGML_USE_METAL
-#include <unistd.h>
-#endif
-
 #if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
 #undef GGML_USE_LLAMAFILE
 #endif

@@ -189,6 +185,8 @@ typedef pthread_t ggml_thread_t;
 #endif

 #if defined(__APPLE__)
+#include <unistd.h>
+#include <mach/mach.h>
 #include <TargetConditionals.h>
 #endif


@@ -386,22 +384,40 @@ void ggml_log_callback_default(enum ggml_log_level level, const char * text, voi
 //#define GGML_SOFT_MAX_ACCELERATE
 #endif

+
+void * ggml_aligned_malloc(size_t size) {
 #if defined(_MSC_VER) || defined(__MINGW32__)
-#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
-#define GGML_ALIGNED_FREE(ptr)    _aligned_free(ptr)
+    return _aligned_malloc(size, TENSOR_ALIGNMENT);
 #else
-inline static void * ggml_aligned_malloc(size_t size) {
     if (size == 0) {
         GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");
         return NULL;
     }
     void * aligned_memory = NULL;
 #ifdef GGML_USE_CPU_HBM
-    int result = hbw_posix_memalign(&aligned_memory, 16, size);
+    int result = hbw_posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size);
+#elif TARGET_OS_OSX
+    kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE);
+    int result = EFAULT;
+    switch (alloc_status) {
+        case KERN_SUCCESS:
+            result = 0;
+            break;
+        case KERN_INVALID_ADDRESS:
+            result = EINVAL;
+            break;
+        case KERN_NO_SPACE:
+            result = ENOMEM;
+            break;
+        default:
+            result = EFAULT;
+            break;
+    }
 #elif GGML_USE_METAL
-    int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size);
+    const long page_size = sysconf(_SC_PAGESIZE);
+    int result = posix_memalign(&aligned_memory, MAX(TENSOR_ALIGNMENT, page_size), size);
 #else
-    int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+    int result = posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size);
 #endif
     if (result != 0) {
         // Handle allocation failure

@@ -419,14 +435,26 @@ inline static void * ggml_aligned_malloc(size_t size) {
         return NULL;
     }
     return aligned_memory;
+#endif
 }
-#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
-#ifdef GGML_USE_CPU_HBM
-#define GGML_ALIGNED_FREE(ptr)    if(NULL != ptr) hbw_free(ptr)
+
+void ggml_aligned_free(void * ptr, size_t size) {
+    GGML_UNUSED(size);
+#if defined(_MSC_VER) || defined(__MINGW32__)
+    _aligned_free(ptr);
+#elif GGML_USE_CPU_HBM
+    if (ptr != NULL) {
+        hbw_free(ptr);
+    }
+#elif TARGET_OS_OSX
+    if (ptr != NULL) {
+        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ptr, size);
+    }
 #else
-#define GGML_ALIGNED_FREE(ptr)    free(ptr)
-#endif
+    free(ptr);
 #endif
+}
+

 inline static void * ggml_malloc(size_t size) {
     if (size == 0) {

@@ -3869,7 +3897,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {

     *ctx = (struct ggml_context) {
         /*.mem_size         =*/ mem_size,
-        /*.mem_buffer       =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
+        /*.mem_buffer       =*/ params.mem_buffer ? params.mem_buffer : ggml_aligned_malloc(mem_size),
         /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
         /*.no_alloc         =*/ params.no_alloc,
         /*.no_alloc_save    =*/ params.no_alloc,

@@ -3909,7 +3937,7 @@ void ggml_free(struct ggml_context * ctx) {
                 __func__, i, ggml_used_mem(ctx));

             if (ctx->mem_buffer_owned) {
-                GGML_ALIGNED_FREE(ctx->mem_buffer);
+                ggml_aligned_free(ctx->mem_buffer, ctx->mem_size);
             }

             found = true;

@@ -19608,9 +19636,10 @@ static void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask
 void ggml_threadpool_free(struct ggml_threadpool* threadpool) {
     if (!threadpool) return;

+    const int n_threads = threadpool->n_threads_max;
+
 #ifndef GGML_USE_OPENMP
     struct ggml_compute_state* workers = threadpool->workers;
-    const int n_threads = threadpool->n_threads_max;

     ggml_mutex_lock(&threadpool->mutex);


@@ -19630,8 +19659,9 @@ void ggml_threadpool_free(struct ggml_threadpool* threadpool) {
     ggml_cond_destroy(&threadpool->cond);
 #endif // GGML_USE_OPENMP

-    GGML_ALIGNED_FREE(threadpool->workers);
-    GGML_ALIGNED_FREE(threadpool);
+    const size_t workers_size = sizeof(struct ggml_compute_state) * n_threads;
+    ggml_aligned_free(threadpool->workers, workers_size);
+    ggml_aligned_free(threadpool, sizeof(struct ggml_threadpool));
 }

 #ifndef GGML_USE_OPENMP

@@ -20063,7 +20093,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
     struct ggml_cplan * cplan) {

     struct ggml_threadpool * threadpool =
-        GGML_ALIGNED_MALLOC(sizeof(struct ggml_threadpool));
+        ggml_aligned_malloc(sizeof(struct ggml_threadpool));
     {
         threadpool->cgraph = cgraph;
         threadpool->cplan  = cplan;

@@ -20084,7 +20114,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(

     // Allocate and init workers state
     const size_t workers_size = sizeof(struct ggml_compute_state) * tpp->n_threads;
-    struct ggml_compute_state * workers = GGML_ALIGNED_MALLOC(workers_size);
+    struct ggml_compute_state * workers = ggml_aligned_malloc(workers_size);

     memset(workers, 0, workers_size);
     for (int j = 0; j < tpp->n_threads; j++) {
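
A side note on the new macOS branch above: vm_allocate() hands back page-aligned, zero-filled memory, and its kern_return_t status is folded into the errno-style result that the shared "result != 0" failure path already expects; the matching vm_deallocate() call is the reason ggml_aligned_free() now needs the allocation size. A compact restatement of that mapping (the helper name is mine; the logic mirrors the hunk above):

#include <errno.h>
#include <mach/mach.h>

// Sketch: translate a Mach vm_allocate() status into the errno-style codes
// returned by the posix_memalign() branches, so every platform can share the
// same failure handling.
static int mach_alloc_status_to_errno(kern_return_t status) {
    switch (status) {
        case KERN_SUCCESS:         return 0;       // allocation succeeded
        case KERN_INVALID_ADDRESS: return EINVAL;  // invalid address argument
        case KERN_NO_SPACE:        return ENOMEM;  // address space exhausted
        default:                   return EFAULT;  // any other failure
    }
}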

src/llama-vocab.cpp

Lines changed: 2 additions & 2 deletions

@@ -221,7 +221,7 @@ struct llm_tokenizer_spm_session {
        }

        // seed the work queue with all possible 2-character tokens.
-        for (size_t i = 1; i < symbols.size(); ++i) {
+        for (int i = 1; i < (int) symbols.size(); ++i) {
            try_add_bigram(i - 1, i);
        }


@@ -563,7 +563,7 @@ struct llm_tokenizer_bpe_session {
            index++;
            symbols.emplace_back(sym);
        }
-        for (size_t i = 1; i < symbols.size(); ++i) {
+        for (int i = 1; i < (int) symbols.size(); ++i) {
            add_new_bigram(i - 1, i);
        }


tests/test-json-schema-to-grammar.cpp

Lines changed: 17 additions & 4 deletions

@@ -696,7 +696,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
            "pattern": "^abc?d*efg+(hij)?kl$"
        })""",
        R"""(
-            root ::= "\"" "ab" "c"? "d"* "ef" "g"+ ("hij")? "kl" "\"" space
+            root ::= "\"" ("ab" "c"? "d"* "ef" "g"+ ("hij")? "kl") "\"" space
            space ::= | " " | "\n" [ \t]{0,20}
        )"""
    });

@@ -709,7 +709,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
            "pattern": "^\\[\\]\\{\\}\\(\\)\\|\\+\\*\\?$"
        })""",
        R"""(
-            root ::= "\"" "[]{}()|+*?" "\"" space
+            root ::= "\"" ("[]{}()|+*?") "\"" space
            space ::= | " " | "\n" [ \t]{0,20}
        )"""
    });

@@ -722,7 +722,20 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
            "pattern": "^\"$"
        })""",
        R"""(
-            root ::= "\"" "\"" "\"" space
+            root ::= "\"" ("\"") "\"" space
+            space ::= | " " | "\n" [ \t]{0,20}
+        )"""
+    });
+
+    test({
+        SUCCESS,
+        "regexp with top-level alternation",
+        R"""({
+            "type": "string",
+            "pattern": "^A|B|C|D$"
+        })""",
+        R"""(
+            root ::= "\"" ("A" | "B" | "C" | "D") "\"" space
            space ::= | " " | "\n" [ \t]{0,20}
        )"""
    });

@@ -736,7 +749,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
        })""",
        R"""(
            dot ::= [^\x0A\x0D]
-            root ::= "\"" ("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot "\"" space
+            root ::= "\"" (("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot) "\"" space
            root-1 ::= [0-9]
            space ::= | " " | "\n" [ \t]{0,20}
        )"""
