Commit 90ed7a6

Choose the logit_bias sampler instead of the potentially unsafe logits_processor
1 parent 004a579 commit 90ed7a6

File tree

2 files changed: 29 additions, 36 deletions


llama_cpp/_internals.py

Lines changed: 10 additions & 5 deletions
@@ -982,12 +982,17 @@ def add_dry(
         )
         self._add_sampler(sampler)

-    def init_logit_bias(
-        self, n_vocab: int, n_logit_bias, logit_bias: llama_cpp.llama_logit_bias_p
+    def add_logit_bias(
+        self, n_vocab: int, logit_bias: Dict[int, float]
     ):
-        sampler = llama_cpp.llama_sampler_init_logit_bias(
-            n_vocab, n_logit_bias, logit_bias
-        )
+        # Construct a C array to store the contents of the logit_bias dictionary
+        logit_bias_array = (llama_cpp.llama_logit_bias * len(logit_bias))()
+
+        for i, (token, bias) in enumerate(logit_bias.items()):
+            logit_bias_array[i].token = token
+            logit_bias_array[i].bias = bias
+
+        sampler = llama_cpp.llama_sampler_init_logit_bias(n_vocab, len(logit_bias), logit_bias_array)
         self._add_sampler(sampler)

     def add_custom(
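
The new add_logit_bias marshals the Python dict into a contiguous C array of llama_logit_bias structs before handing it to llama.cpp. Below is a minimal, self-contained sketch of that ctypes pattern; the struct is a local stand-in so the snippet runs without the library (per the diff, the real llama_cpp.llama_logit_bias exposes the same token and bias fields):

    import ctypes

    # Local stand-in for llama_cpp.llama_logit_bias (same two fields,
    # defined here only so this snippet runs without the library).
    class LogitBias(ctypes.Structure):
        _fields_ = [("token", ctypes.c_int32), ("bias", ctypes.c_float)]

    logit_bias = {42: 5.0, 7: -100.0}  # token id -> bias (example values)

    # (StructType * n)() allocates a zero-initialized, contiguous C array.
    arr = (LogitBias * len(logit_bias))()
    for i, (token, bias) in enumerate(logit_bias.items()):
        arr[i].token = token
        arr[i].bias = bias

    # The array plus its length can now be passed to a C function such as
    # llama_sampler_init_logit_bias; arr stays alive for the call.
    assert arr[1].bias == -100.0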

llama_cpp/llama.py

Lines changed: 19 additions & 31 deletions
@@ -693,32 +693,14 @@ def _init_sampler(
         dry_penalty_last_n:int = 0,
         dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"],
         penalize_nl: bool = True,
+        logit_bias: Optional[Dict[int, float]] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
     ):
         sampler = internals.LlamaSampler()

-        if logits_processor is not None:
-            # Create and add a custom sampler
-            def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
-                size = token_data_array.contents.size
-                data_soa = token_data_array.contents.data
-                data_soa_address = ctypes.addressof(data_soa.contents)
-                # NOTE: This is probably broken
-                recarray = np.recarray(
-                    shape=(size,),
-                    dtype=np.dtype(
-                        [("id", np.intc), ("logit", np.single), ("p", np.single)],
-                        align=True,
-                    ),
-                    buf=(llama_cpp.llama_token_data * size).from_address(
-                        data_soa_address
-                    ),
-                )
-                for logit_processor in logits_processor:
-                    recarray.logit[:] = logit_processor(self._input_ids, recarray.logit)
-
-            sampler.add_custom(apply_func)
+        if logit_bias is not None:
+            sampler.add_logit_bias(self.n_vocab(), logit_bias)

         sampler.add_penalties(
             n_vocab=self._n_vocab,
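
For context on what the removed block did (and why it carried the "NOTE: This is probably broken" comment): it viewed the C token-data array as a NumPy recarray so Python logits processors could mutate logits in place, which only works if the dtype layout exactly matches the C struct. A standalone sketch of that pattern, with a locally defined stand-in struct mirroring the removed code's dtype:

    import ctypes

    import numpy as np

    # Local mirror of llama_token_data for illustration (matches the dtype
    # the removed code used: int32 id, float32 logit, float32 p).
    class TokenData(ctypes.Structure):
        _fields_ = [("id", ctypes.c_int32), ("logit", ctypes.c_float), ("p", ctypes.c_float)]

    size = 4
    buf = (TokenData * size)()
    for i in range(size):
        buf[i].id, buf[i].logit = i, float(i)

    # The recarray shares memory with the ctypes buffer, so writes to
    # rec.logit are visible through buf.
    rec = np.recarray(
        shape=(size,),
        dtype=np.dtype([("id", np.intc), ("logit", np.single), ("p", np.single)], align=True),
        buf=buf,
    )
    rec.logit[:] = rec.logit * 2.0
    assert buf[3].logit == 6.0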
@@ -792,6 +774,7 @@ def sample(
         dry_penalty_last_n:int = 0,
         dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"],
         penalize_nl: bool = True,
+        logit_bias: Optional[Dict[int, float]] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
         idx: Optional[int] = None,
@@ -834,6 +817,7 @@ def sample(
             dry_penalty_last_n=dry_penalty_last_n,
             dry_seq_breakers=dry_seq_breakers,
             penalize_nl=penalize_nl,
+            logit_bias=logit_bias,
             logits_processor=logits_processor,
             grammar=grammar,
         )
@@ -870,6 +854,7 @@ def generate(
         dry_penalty_last_n:int = 0,
         dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"],
         penalize_nl: bool = True,
+        logit_bias: Optional[Dict[int, float]] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -916,6 +901,7 @@ def generate(
             dry_penalty_last_n=dry_penalty_last_n,
             dry_seq_breakers=dry_seq_breakers,
             penalize_nl=penalize_nl,
+            logit_bias=logit_bias,
             logits_processor=logits_processor,
             grammar=grammar,
         )
@@ -974,6 +960,7 @@ def generate(
                 dry_allowed_length=dry_allowed_length,
                 dry_penalty_last_n=dry_penalty_last_n,
                 dry_seq_breakers=dry_seq_breakers,
+                logit_bias=logit_bias,
                 logits_processor=logits_processor,
                 grammar=grammar,
                 penalize_nl=penalize_nl,
@@ -1199,9 +1186,9 @@ def _create_completion(
         dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        logit_bias: Optional[Dict[int, float]] = None,
     ) -> Union[
         Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
     ]:
@@ -1396,6 +1383,7 @@ def logit_bias_processor(
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
             stopping_criteria=stopping_criteria,
+            logit_bias=logit_bias,
             logits_processor=logits_processor,
             grammar=grammar,
         ):
@@ -1833,9 +1821,9 @@ def create_completion(
         dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        logit_bias: Optional[Dict[int, float]] = None,
     ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
         """Generate text from a prompt.

@@ -1869,9 +1857,9 @@ def create_completion(
             dry_seq_breakers: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n', ':', '"', '*']`
             model: The name to use for the model in the completion object.
             stopping_criteria: A list of stopping criteria to use.
+            logit_bias: A logit bias to use.
             logits_processor: A list of logits processors to use.
             grammar: A grammar to use for constrained sampling.
-            logit_bias: A logit bias to use.

         Raises:
             ValueError: If the requested tokens exceed the context window.
@@ -1910,9 +1898,9 @@ def create_completion(
             dry_seq_breakers=dry_seq_breakers,
             model=model,
             stopping_criteria=stopping_criteria,
+            logit_bias=logit_bias,
             logits_processor=logits_processor,
             grammar=grammar,
-            logit_bias=logit_bias,
         )
         if stream:
             chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks
@@ -1951,9 +1939,9 @@ def __call__(
         dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        logit_bias: Optional[Dict[int, float]] = None,
     ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
         """Generate text from a prompt.

@@ -1987,9 +1975,9 @@ def __call__(
             dry_seq_breakers: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n', ':', '"', '*']`
             model: The name to use for the model in the completion object.
             stopping_criteria: A list of stopping criteria to use.
+            logit_bias: A logit bias to use.
             logits_processor: A list of logits processors to use.
             grammar: A grammar to use for constrained sampling.
-            logit_bias: A logit bias to use.

         Raises:
             ValueError: If the requested tokens exceed the context window.
@@ -2028,9 +2016,9 @@ def __call__(
             dry_seq_breakers=dry_seq_breakers,
             model=model,
             stopping_criteria=stopping_criteria,
+            logit_bias=logit_bias,
             logits_processor=logits_processor,
             grammar=grammar,
-            logit_bias=logit_bias,
         )

     def create_chat_completion(
@@ -2065,9 +2053,9 @@ def create_chat_completion(
         dry_penalty_last_n:int = 0,
         dry_seq_breakers: list[str] = ["\n", ":", "\"", "*"],
         model: Optional[str] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        logit_bias: Optional[Dict[int, float]] = None,
         logprobs: Optional[bool] = None,
         top_logprobs: Optional[int] = None,
     ) -> Union[
@@ -2106,9 +2094,9 @@ def create_chat_completion(
             dry_penalty_last_n: How many tokens to scan for repetitions. Default: `0`, where `0` is disabled and `-1` is context size.
             dry_seq_breakers: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n', ':', '"', '*']`
             model: The name to use for the model in the completion object.
+            logit_bias: A logit bias to use.
             logits_processor: A list of logits processors to use.
             grammar: A grammar to use.
-            logit_bias: A logit bias to use.

         Returns:
             Generated chat completion or a stream of chat completion chunks.
@@ -2152,9 +2140,9 @@ def create_chat_completion(
             dry_penalty_last_n=dry_penalty_last_n,
             dry_seq_breakers=dry_seq_breakers,
             model=model,
+            logit_bias=logit_bias,
             logits_processor=logits_processor,
             grammar=grammar,
-            logit_bias=logit_bias,
         )

     def create_chat_completion_openai_v1(
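
After this change, callers pass logit_bias straight through the sampling APIs and it is applied by llama.cpp's native logit_bias sampler rather than a Python callback. A usage sketch (the model path and prompt are placeholders, and token ids are model-specific):

    from llama_cpp import Llama

    llm = Llama(model_path="./model.gguf")  # placeholder path

    # Bias a token id up (more likely); a large negative bias such as -100
    # effectively bans a token.
    token_id = llm.tokenize(b" yes", add_bos=False)[0]  # id depends on the model
    out = llm.create_completion(
        "Answer yes or no:",
        max_tokens=4,
        logit_bias={token_id: 10.0},
    )
    print(out["choices"][0]["text"])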
