Skip to content

Commit 0523859

Browse files
committed
class LlamaSampler: append add_dry()
Fix the char array params convert problem
1 parent 0874bac commit 0523859

File tree

2 files changed

+63
-8
lines changed

2 files changed

+63
-8
lines changed

llama_cpp/_internals.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -821,20 +821,51 @@ def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
821821
)
822822
self._add_sampler(sampler)
823823

824+
def convert_list_str_to_char_ptr_array(str_list: List[str]) -> ctypes.POINTER(ctypes.POINTER(ctypes.c_char)):
825+
"""
826+
Converts a list of strings to a char** array for C interop.
827+
Args:
828+
list[str]: List of string objects.
829+
Returns:
830+
A ctypes pointer to a char** array.
831+
"""
832+
# Encode strings to bytes
833+
byte_list = [s.encode('utf-8') for s in str_list]
834+
# Calculate the number of breakers
835+
num_byte_list= len(byte_list)
836+
# Define the type of a char pointer
837+
char_ptr_type = ctypes.POINTER(ctypes.c_char)
838+
# Define the type of an array of char pointers
839+
char_ptr_array_type = char_ptr_type * num_byte_list
840+
841+
# Allocate memory for the array of char pointers
842+
char_ptr_array = char_ptr_array_type()
843+
844+
# Populate the array with pointers to the byte strings
845+
for i, byte_string in enumerate(byte_list):
846+
# Create a null-terminated C-style string buffer
847+
c_char_array = ctypes.create_string_buffer(byte_string)
848+
# Cast the buffer to a char pointer and assign it to the array
849+
char_ptr_array[i] = ctypes.cast(c_char_array, char_ptr_type)
850+
851+
# Cast the array to a char** pointer and return it
852+
return ctypes.cast(char_ptr_array, ctypes.POINTER(char_ptr_type)), num_byte_list
853+
824854
def add_grammar_lazy(
825855
self,
826856
model: LlamaModel,
827857
grammar: LlamaGrammar,
828-
trigger_words: list[bytes],
829-
num_trigger_words: int,
830858
trigger_tokens:list[llama_cpp.llama_token],
831-
num_trigger_tokens: int
859+
num_trigger_tokens: int,
860+
trigger_words: list[str]=[]
832861
):
862+
trigger_words_char_ptr_array, num_trigger_words = self.convert_list_str_to_char_ptr_array(trigger_words)
863+
833864
sampler = llama_cpp.llama_sampler_init_grammar_lazy(
834865
model.vocab,
835866
grammar._grammar.encode("utf-8"),
836867
grammar._root.encode("utf-8"),
837-
trigger_words,
868+
trigger_words_char_ptr_array,
838869
num_trigger_words,
839870
trigger_tokens,
840871
num_trigger_tokens
@@ -845,16 +876,17 @@ def add_grammar_lazy_patterns(
845876
self,
846877
model: LlamaModel,
847878
grammar: LlamaGrammar,
848-
trigger_patterns: list[bytes],
849879
num_trigger_patterns: int,
850880
trigger_tokens:list[llama_cpp.llama_token],
851-
num_trigger_tokens: int
881+
num_trigger_tokens: int,
882+
trigger_patterns: list[str]=[]
852883
):
884+
trigger_patterns_char_ptr_array, num_trigger_patterns = self.convert_list_str_to_char_ptr_array(trigger_patterns)
853885
sampler = llama_cpp.llama_sampler_init_grammar_lazy_patterns(
854886
model.vocab,
855887
grammar._grammar.encode("utf-8"),
856888
grammar._root.encode("utf-8"),
857-
trigger_patterns,
889+
trigger_patterns_char_ptr_array,
858890
num_trigger_patterns,
859891
trigger_tokens,
860892
num_trigger_tokens
@@ -882,6 +914,29 @@ def add_penalties(
882914
)
883915
self._add_sampler(sampler)
884916

917+
def add_dry(
918+
self,
919+
model: LlamaModel,
920+
n_ctx_train: int,
921+
dry_multiplier: float,
922+
dry_base: float,
923+
dry_allowed_length: int,
924+
dry_penalty_last_n: int,
925+
seq_breakers: list[str] = []
926+
):
927+
seq_breakers_bytes_char_ptr_array, num_breakers = self.convert_list_str_to_char_ptr_array(seq_breakers)
928+
sampler = llama_cpp.llama_sampler_init_dry(
929+
model.vocab,
930+
n_ctx_train,
931+
dry_multiplier,
932+
dry_base,
933+
dry_allowed_length,
934+
dry_penalty_last_n,
935+
seq_breakers_bytes_char_ptr_array,
936+
num_breakers
937+
)
938+
self._add_sampler(sampler)
939+
885940
def init_logit_bias(
886941
self, n_vocab: int, n_logit_bias, logit_bias: llama_cpp.llama_logit_bias_p
887942
):

llama_cpp/llama_cpp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3413,7 +3413,7 @@ class llama_sampler(ctypes.Structure):
34133413
("clone", llama_sampler_i_clone),
34143414
("free", llama_sampler_i_free),
34153415
]
3416-
llama_sampler_i_p = CtypesPointer[llama_sampler_i]
3416+
llama_sampler_i_p = ctypes.POINTER(llama_sampler_i)
34173417

34183418
# // mirror of llama_sampler_i:
34193419

0 commit comments

Comments
 (0)