@@ -821,20 +821,51 @@ def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
821821 )
822822 self ._add_sampler (sampler )
823823
824+ def convert_list_str_to_char_ptr_array (str_list : List [str ]) -> ctypes .POINTER (ctypes .POINTER (ctypes .c_char )):
825+ """
826+ Converts a list of strings to a char** array for C interop.
827+ Args:
828+ list[str]: List of string objects.
829+ Returns:
830+ A ctypes pointer to a char** array.
831+ """
832+ # Encode strings to bytes
833+ byte_list = [s .encode ('utf-8' ) for s in str_list ]
834+ # Calculate the number of breakers
835+ num_byte_list = len (byte_list )
836+ # Define the type of a char pointer
837+ char_ptr_type = ctypes .POINTER (ctypes .c_char )
838+ # Define the type of an array of char pointers
839+ char_ptr_array_type = char_ptr_type * num_byte_list
840+
841+ # Allocate memory for the array of char pointers
842+ char_ptr_array = char_ptr_array_type ()
843+
844+ # Populate the array with pointers to the byte strings
845+ for i , byte_string in enumerate (byte_list ):
846+ # Create a null-terminated C-style string buffer
847+ c_char_array = ctypes .create_string_buffer (byte_string )
848+ # Cast the buffer to a char pointer and assign it to the array
849+ char_ptr_array [i ] = ctypes .cast (c_char_array , char_ptr_type )
850+
851+ # Cast the array to a char** pointer and return it
852+ return ctypes .cast (char_ptr_array , ctypes .POINTER (char_ptr_type )), num_byte_list
853+
824854 def add_grammar_lazy (
825855 self ,
826856 model : LlamaModel ,
827857 grammar : LlamaGrammar ,
828- trigger_words : list [bytes ],
829- num_trigger_words : int ,
830858 trigger_tokens :list [llama_cpp .llama_token ],
831- num_trigger_tokens : int
859+ num_trigger_tokens : int ,
860+ trigger_words : list [str ]= []
832861 ):
862+ trigger_words_char_ptr_array , num_trigger_words = self .convert_list_str_to_char_ptr_array (trigger_words )
863+
833864 sampler = llama_cpp .llama_sampler_init_grammar_lazy (
834865 model .vocab ,
835866 grammar ._grammar .encode ("utf-8" ),
836867 grammar ._root .encode ("utf-8" ),
837- trigger_words ,
868+ trigger_words_char_ptr_array ,
838869 num_trigger_words ,
839870 trigger_tokens ,
840871 num_trigger_tokens
@@ -845,16 +876,17 @@ def add_grammar_lazy_patterns(
845876 self ,
846877 model : LlamaModel ,
847878 grammar : LlamaGrammar ,
848- trigger_patterns : list [bytes ],
849879 num_trigger_patterns : int ,
850880 trigger_tokens :list [llama_cpp .llama_token ],
851- num_trigger_tokens : int
881+ num_trigger_tokens : int ,
882+ trigger_patterns : list [str ]= []
852883 ):
884+ trigger_patterns_char_ptr_array , num_trigger_patterns = self .convert_list_str_to_char_ptr_array (trigger_patterns )
853885 sampler = llama_cpp .llama_sampler_init_grammar_lazy_patterns (
854886 model .vocab ,
855887 grammar ._grammar .encode ("utf-8" ),
856888 grammar ._root .encode ("utf-8" ),
857- trigger_patterns ,
889+ trigger_patterns_char_ptr_array ,
858890 num_trigger_patterns ,
859891 trigger_tokens ,
860892 num_trigger_tokens
@@ -882,6 +914,29 @@ def add_penalties(
882914 )
883915 self ._add_sampler (sampler )
884916
917+ def add_dry (
918+ self ,
919+ model : LlamaModel ,
920+ n_ctx_train : int ,
921+ dry_multiplier : float ,
922+ dry_base : float ,
923+ dry_allowed_length : int ,
924+ dry_penalty_last_n : int ,
925+ seq_breakers : list [str ] = []
926+ ):
927+ seq_breakers_bytes_char_ptr_array , num_breakers = self .convert_list_str_to_char_ptr_array (seq_breakers )
928+ sampler = llama_cpp .llama_sampler_init_dry (
929+ model .vocab ,
930+ n_ctx_train ,
931+ dry_multiplier ,
932+ dry_base ,
933+ dry_allowed_length ,
934+ dry_penalty_last_n ,
935+ seq_breakers_bytes_char_ptr_array ,
936+ num_breakers
937+ )
938+ self ._add_sampler (sampler )
939+
885940 def init_logit_bias (
886941 self , n_vocab : int , n_logit_bias , logit_bias : llama_cpp .llama_logit_bias_p
887942 ):
0 commit comments