@@ -693,32 +693,14 @@ def _init_sampler(
693693 dry_penalty_last_n :int = 0 ,
694694 dry_seq_breakers : list [str ] = ["\n " , ":" , "\" " , "*" ],
695695 penalize_nl : bool = True ,
696+ logit_bias : Optional [Dict [int , float ]] = None ,
696697 logits_processor : Optional [LogitsProcessorList ] = None ,
697698 grammar : Optional [LlamaGrammar ] = None ,
698699 ):
699700 sampler = internals .LlamaSampler ()
700701
701- if logits_processor is not None :
702- # Create and add a custom sampler
703- def apply_func (token_data_array : llama_cpp .llama_token_data_array_p ):
704- size = token_data_array .contents .size
705- data_soa = token_data_array .contents .data
706- data_soa_address = ctypes .addressof (data_soa .contents )
707- # NOTE: This is probably broken
708- recarray = np .recarray (
709- shape = (size ,),
710- dtype = np .dtype (
711- [("id" , np .intc ), ("logit" , np .single ), ("p" , np .single )],
712- align = True ,
713- ),
714- buf = (llama_cpp .llama_token_data * size ).from_address (
715- data_soa_address
716- ),
717- )
718- for logit_processor in logits_processor :
719- recarray .logit [:] = logit_processor (self ._input_ids , recarray .logit )
720-
721- sampler .add_custom (apply_func )
702+ if logit_bias is not None :
703+ sampler .add_logit_bias (self .n_vocab (), logit_bias )
722704
723705 sampler .add_penalties (
724706 n_vocab = self ._n_vocab ,
@@ -792,6 +774,7 @@ def sample(
792774 dry_penalty_last_n :int = 0 ,
793775 dry_seq_breakers : list [str ] = ["\n " , ":" , "\" " , "*" ],
794776 penalize_nl : bool = True ,
777+ logit_bias : Optional [Dict [int , float ]] = None ,
795778 logits_processor : Optional [LogitsProcessorList ] = None ,
796779 grammar : Optional [LlamaGrammar ] = None ,
797780 idx : Optional [int ] = None ,
@@ -834,6 +817,7 @@ def sample(
834817 dry_penalty_last_n = dry_penalty_last_n ,
835818 dry_seq_breakers = dry_seq_breakers ,
836819 penalize_nl = penalize_nl ,
820+ logit_bias = logit_bias ,
837821 logits_processor = logits_processor ,
838822 grammar = grammar ,
839823 )
@@ -870,6 +854,7 @@ def generate(
870854 dry_penalty_last_n :int = 0 ,
871855 dry_seq_breakers : list [str ] = ["\n " , ":" , "\" " , "*" ],
872856 penalize_nl : bool = True ,
857+ logit_bias : Optional [Dict [int , float ]] = None ,
873858 logits_processor : Optional [LogitsProcessorList ] = None ,
874859 stopping_criteria : Optional [StoppingCriteriaList ] = None ,
875860 grammar : Optional [LlamaGrammar ] = None ,
@@ -916,6 +901,7 @@ def generate(
916901 dry_penalty_last_n = dry_penalty_last_n ,
917902 dry_seq_breakers = dry_seq_breakers ,
918903 penalize_nl = penalize_nl ,
904+ logit_bias = logit_bias ,
919905 logits_processor = logits_processor ,
920906 grammar = grammar ,
921907 )
@@ -974,6 +960,7 @@ def generate(
974960 dry_allowed_length = dry_allowed_length ,
975961 dry_penalty_last_n = dry_penalty_last_n ,
976962 dry_seq_breakers = dry_seq_breakers ,
963+ logit_bias = logit_bias ,
977964 logits_processor = logits_processor ,
978965 grammar = grammar ,
979966 penalize_nl = penalize_nl ,
@@ -1199,9 +1186,9 @@ def _create_completion(
11991186 dry_seq_breakers : list [str ] = ["\n " , ":" , "\" " , "*" ],
12001187 model : Optional [str ] = None ,
12011188 stopping_criteria : Optional [StoppingCriteriaList ] = None ,
1189+ logit_bias : Optional [Dict [int , float ]] = None ,
12021190 logits_processor : Optional [LogitsProcessorList ] = None ,
12031191 grammar : Optional [LlamaGrammar ] = None ,
1204- logit_bias : Optional [Dict [int , float ]] = None ,
12051192 ) -> Union [
12061193 Iterator [CreateCompletionResponse ], Iterator [CreateCompletionStreamResponse ]
12071194 ]:
@@ -1396,6 +1383,7 @@ def logit_bias_processor(
13961383 presence_penalty = presence_penalty ,
13971384 repeat_penalty = repeat_penalty ,
13981385 stopping_criteria = stopping_criteria ,
1386+ logit_bias = logit_bias ,
13991387 logits_processor = logits_processor ,
14001388 grammar = grammar ,
14011389 ):
@@ -1833,9 +1821,9 @@ def create_completion(
18331821 dry_seq_breakers : list [str ] = ["\n " , ":" , "\" " , "*" ],
18341822 model : Optional [str ] = None ,
18351823 stopping_criteria : Optional [StoppingCriteriaList ] = None ,
1824+ logit_bias : Optional [Dict [int , float ]] = None ,
18361825 logits_processor : Optional [LogitsProcessorList ] = None ,
18371826 grammar : Optional [LlamaGrammar ] = None ,
1838- logit_bias : Optional [Dict [int , float ]] = None ,
18391827 ) -> Union [CreateCompletionResponse , Iterator [CreateCompletionStreamResponse ]]:
18401828 """Generate text from a prompt.
18411829
@@ -1869,9 +1857,9 @@ def create_completion(
18691857 dry_seq_breakers: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n ', ':', '"', '*']`
18701858 model: The name to use for the model in the completion object.
18711859 stopping_criteria: A list of stopping criteria to use.
1860+ logit_bias: A logit bias to use.
18721861 logits_processor: A list of logits processors to use.
18731862 grammar: A grammar to use for constrained sampling.
1874- logit_bias: A logit bias to use.
18751863
18761864 Raises:
18771865 ValueError: If the requested tokens exceed the context window.
@@ -1910,9 +1898,9 @@ def create_completion(
19101898 dry_seq_breakers = dry_seq_breakers ,
19111899 model = model ,
19121900 stopping_criteria = stopping_criteria ,
1901+ logit_bias = logit_bias ,
19131902 logits_processor = logits_processor ,
19141903 grammar = grammar ,
1915- logit_bias = logit_bias ,
19161904 )
19171905 if stream :
19181906 chunks : Iterator [CreateCompletionStreamResponse ] = completion_or_chunks
@@ -1951,9 +1939,9 @@ def __call__(
19511939 dry_seq_breakers : list [str ] = ["\n " , ":" , "\" " , "*" ],
19521940 model : Optional [str ] = None ,
19531941 stopping_criteria : Optional [StoppingCriteriaList ] = None ,
1942+ logit_bias : Optional [Dict [int , float ]] = None ,
19541943 logits_processor : Optional [LogitsProcessorList ] = None ,
19551944 grammar : Optional [LlamaGrammar ] = None ,
1956- logit_bias : Optional [Dict [int , float ]] = None ,
19571945 ) -> Union [CreateCompletionResponse , Iterator [CreateCompletionStreamResponse ]]:
19581946 """Generate text from a prompt.
19591947
@@ -1987,9 +1975,9 @@ def __call__(
19871975 dry_seq_breakers: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n ', ':', '"', '*']`
19881976 model: The name to use for the model in the completion object.
19891977 stopping_criteria: A list of stopping criteria to use.
1978+ logit_bias: A logit bias to use.
19901979 logits_processor: A list of logits processors to use.
19911980 grammar: A grammar to use for constrained sampling.
1992- logit_bias: A logit bias to use.
19931981
19941982 Raises:
19951983 ValueError: If the requested tokens exceed the context window.
@@ -2028,9 +2016,9 @@ def __call__(
20282016 dry_seq_breakers = dry_seq_breakers ,
20292017 model = model ,
20302018 stopping_criteria = stopping_criteria ,
2019+ logit_bias = logit_bias ,
20312020 logits_processor = logits_processor ,
20322021 grammar = grammar ,
2033- logit_bias = logit_bias ,
20342022 )
20352023
20362024 def create_chat_completion (
@@ -2065,9 +2053,9 @@ def create_chat_completion(
20652053 dry_penalty_last_n :int = 0 ,
20662054 dry_seq_breakers : list [str ] = ["\n " , ":" , "\" " , "*" ],
20672055 model : Optional [str ] = None ,
2056+ logit_bias : Optional [Dict [int , float ]] = None ,
20682057 logits_processor : Optional [LogitsProcessorList ] = None ,
20692058 grammar : Optional [LlamaGrammar ] = None ,
2070- logit_bias : Optional [Dict [int , float ]] = None ,
20712059 logprobs : Optional [bool ] = None ,
20722060 top_logprobs : Optional [int ] = None ,
20732061 ) -> Union [
@@ -2106,9 +2094,9 @@ def create_chat_completion(
21062094 dry_penalty_last_n: How many tokens to scan for repetitions. Default: `0`, where `0` is disabled and `-1` is context size.
21072095 dry_seq_breakers: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n ', ':', '"', '*']`
21082096 model: The name to use for the model in the completion object.
2097+ logit_bias: A logit bias to use.
21092098 logits_processor: A list of logits processors to use.
21102099 grammar: A grammar to use.
2111- logit_bias: A logit bias to use.
21122100
21132101 Returns:
21142102 Generated chat completion or a stream of chat completion chunks.
@@ -2152,9 +2140,9 @@ def create_chat_completion(
21522140 dry_penalty_last_n = dry_penalty_last_n ,
21532141 dry_seq_breakers = dry_seq_breakers ,
21542142 model = model ,
2143+ logit_bias = logit_bias ,
21552144 logits_processor = logits_processor ,
21562145 grammar = grammar ,
2157- logit_bias = logit_bias ,
21582146 )
21592147
21602148 def create_chat_completion_openai_v1 (