Skip to content

Commit 3d83305

Browse files
committed
Sync aLoRA Support
1 parent 9ae9543 commit 3d83305

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

llama_cpp/llama_cpp.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1836,6 +1836,27 @@ def llama_adapter_lora_free(adapter: llama_adapter_lora_p, /):
18361836
...
18371837

18381838

1839+
# // Get the invocation tokens if the current lora is an alora
1840+
# LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter);
1841+
@ctypes_function(
1842+
"llama_adapter_get_alora_n_invocation_tokens",
1843+
[llama_adapter_lora_p_ctypes],
1844+
ctypes.c_uint64,
1845+
)
1846+
def llama_adapter_get_alora_n_invocation_tokens(adapter: llama_adapter_lora_p, /) -> ctypes.c_uint64:
1847+
...
1848+
1849+
1850+
# LLAMA_API const llama_token * llama_adapter_get_alora_invocation_tokens (const struct llama_adapter_lora * adapter);
1851+
@ctypes_function(
1852+
"llama_adapter_get_alora_invocation_tokens",
1853+
[llama_adapter_lora_p_ctypes],
1854+
llama_token_p,
1855+
)
1856+
def llama_adapter_get_alora_invocation_tokens(adapter: llama_adapter_lora_p, /) -> llama_token_p:
1857+
...
1858+
1859+
18391860
# // The following functions operate on a llama_context, hence the naming: llama_verb_...
18401861

18411862

@@ -3380,7 +3401,7 @@ def llama_token_to_piece(
33803401
"llama_detokenize",
33813402
[
33823403
llama_model_p_ctypes,
3383-
ctypes.POINTER(llama_token),
3404+
llama_token_p,
33843405
ctypes.c_int32,
33853406
ctypes.c_char_p,
33863407
ctypes.c_int32,
@@ -3907,7 +3928,7 @@ def llama_sampler_init_grammar(
39073928
ctypes.c_char_p,
39083929
ctypes.POINTER(ctypes.c_char_p),
39093930
ctypes.c_size_t,
3910-
ctypes.POINTER(llama_token),
3931+
llama_token_p,
39113932
ctypes.c_size_t,
39123933
],
39133934
llama_sampler_p_ctypes,

tests/test_llama.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ def test_real_model(llama_cpp_model_path):
8686
cparams.n_ubatch = 16
8787
cparams.n_threads = multiprocessing.cpu_count()
8888
cparams.n_threads_batch = multiprocessing.cpu_count()
89-
cparams.flash_attn = True
9089
cparams.swa_full = True
9190
cparams.kv_unified = True
9291

0 commit comments

Comments
 (0)