Skip to content

Commit cb49f02

Browse files
committed
Append xtc params to sampler_init
1 parent 0523859 commit cb49f02

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

llama_cpp/llama.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,8 @@ def _init_sampler(
680680
mirostat_mode: int = 0,
681681
mirostat_eta: float = 0.1,
682682
mirostat_tau: float = 5.0,
683+
xtc_threshold: float = 0.1,
684+
xtc_probability: float = 0.0,
683685
penalize_nl: bool = True,
684686
logits_processor: Optional[LogitsProcessorList] = None,
685687
grammar: Optional[LlamaGrammar] = None,
@@ -753,6 +755,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
753755
sampler.add_min_p(min_p, min_keep)
754756
sampler.add_temp(temp)
755757
sampler.add_dist(self._seed)
758+
sampler.add_xtc(xtc_probability, xtc_threshold, min_keep, self._seed)
756759
return sampler
757760

758761
def sample(
@@ -769,6 +772,8 @@ def sample(
769772
mirostat_mode: int = 0,
770773
mirostat_eta: float = 0.1,
771774
mirostat_tau: float = 5.0,
775+
xtc_threshold: float = 0.1,
776+
xtc_probability: float = 0.0,
772777
penalize_nl: bool = True,
773778
logits_processor: Optional[LogitsProcessorList] = None,
774779
grammar: Optional[LlamaGrammar] = None,
@@ -804,6 +809,8 @@ def sample(
804809
mirostat_mode=mirostat_mode,
805810
mirostat_tau=mirostat_tau,
806811
mirostat_eta=mirostat_eta,
812+
xtc_threshold = xtc_threshold,
813+
xtc_probability = xtc_probability,
807814
penalize_nl=penalize_nl,
808815
logits_processor=logits_processor,
809816
grammar=grammar,
@@ -833,6 +840,8 @@ def generate(
833840
mirostat_mode: int = 0,
834841
mirostat_tau: float = 5.0,
835842
mirostat_eta: float = 0.1,
843+
xtc_threshold: float = 0.1,
844+
xtc_probability: float = 0.0,
836845
penalize_nl: bool = True,
837846
logits_processor: Optional[LogitsProcessorList] = None,
838847
stopping_criteria: Optional[StoppingCriteriaList] = None,
@@ -872,6 +881,8 @@ def generate(
872881
mirostat_mode=mirostat_mode,
873882
mirostat_tau=mirostat_tau,
874883
mirostat_eta=mirostat_eta,
884+
xtc_threshold = xtc_threshold,
885+
xtc_probability = xtc_probability,
875886
penalize_nl=penalize_nl,
876887
logits_processor=logits_processor,
877888
grammar=grammar,
@@ -924,6 +935,8 @@ def generate(
924935
mirostat_mode=mirostat_mode,
925936
mirostat_tau=mirostat_tau,
926937
mirostat_eta=mirostat_eta,
938+
xtc_threshold = xtc_threshold,
939+
xtc_probability = xtc_probability,
927940
logits_processor=logits_processor,
928941
grammar=grammar,
929942
penalize_nl=penalize_nl,
@@ -1140,6 +1153,8 @@ def _create_completion(
11401153
mirostat_mode: int = 0,
11411154
mirostat_tau: float = 5.0,
11421155
mirostat_eta: float = 0.1,
1156+
xtc_threshold: float = 0.1,
1157+
xtc_probability: float = 0.0,
11431158
model: Optional[str] = None,
11441159
stopping_criteria: Optional[StoppingCriteriaList] = None,
11451160
logits_processor: Optional[LogitsProcessorList] = None,
@@ -1328,6 +1343,8 @@ def logit_bias_processor(
13281343
mirostat_mode=mirostat_mode,
13291344
mirostat_tau=mirostat_tau,
13301345
mirostat_eta=mirostat_eta,
1346+
xtc_threshold=xtc_threshold,
1347+
xtc_probability=xtc_probability,
13311348
frequency_penalty=frequency_penalty,
13321349
presence_penalty=presence_penalty,
13331350
repeat_penalty=repeat_penalty,
@@ -1760,6 +1777,8 @@ def create_completion(
17601777
mirostat_mode: int = 0,
17611778
mirostat_tau: float = 5.0,
17621779
mirostat_eta: float = 0.1,
1780+
xtc_threshold: float = 0.1,
1781+
xtc_probability: float = 0.0,
17631782
model: Optional[str] = None,
17641783
stopping_criteria: Optional[StoppingCriteriaList] = None,
17651784
logits_processor: Optional[LogitsProcessorList] = None,
@@ -1789,6 +1808,8 @@ def create_completion(
17891808
mirostat_mode: The mirostat sampling mode.
17901809
mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
17911810
mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
1811+
xtc_probability: Sets the chance for token removal (checked once on sampler start) (default: 0.0).
1812+
xtc_threshold: Sets a minimum probability threshold for tokens to be removed (default: 0.1).
17921813
model: The name to use for the model in the completion object.
17931814
stopping_criteria: A list of stopping criteria to use.
17941815
logits_processor: A list of logits processors to use.
@@ -1823,6 +1844,8 @@ def create_completion(
18231844
mirostat_mode=mirostat_mode,
18241845
mirostat_tau=mirostat_tau,
18251846
mirostat_eta=mirostat_eta,
1847+
xtc_threshold=xtc_threshold,
1848+
xtc_probability=xtc_probability,
18261849
model=model,
18271850
stopping_criteria=stopping_criteria,
18281851
logits_processor=logits_processor,
@@ -1857,6 +1880,8 @@ def __call__(
18571880
mirostat_mode: int = 0,
18581881
mirostat_tau: float = 5.0,
18591882
mirostat_eta: float = 0.1,
1883+
xtc_threshold: float = 0.1,
1884+
xtc_probability: float = 0.0,
18601885
model: Optional[str] = None,
18611886
stopping_criteria: Optional[StoppingCriteriaList] = None,
18621887
logits_processor: Optional[LogitsProcessorList] = None,
@@ -1886,6 +1911,8 @@ def __call__(
18861911
mirostat_mode: The mirostat sampling mode.
18871912
mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
18881913
mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
1914+
xtc_probability: Sets the chance for token removal (checked once on sampler start) (default: 0.0). XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
1915+
xtc_threshold: Sets a minimum probability threshold for tokens to be removed (default: 0.1). XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
18891916
model: The name to use for the model in the completion object.
18901917
stopping_criteria: A list of stopping criteria to use.
18911918
logits_processor: A list of logits processors to use.
@@ -1920,6 +1947,8 @@ def __call__(
19201947
mirostat_mode=mirostat_mode,
19211948
mirostat_tau=mirostat_tau,
19221949
mirostat_eta=mirostat_eta,
1950+
xtc_threshold=xtc_threshold,
1951+
xtc_probability=xtc_probability,
19231952
model=model,
19241953
stopping_criteria=stopping_criteria,
19251954
logits_processor=logits_processor,
@@ -1951,6 +1980,8 @@ def create_chat_completion(
19511980
mirostat_mode: int = 0,
19521981
mirostat_tau: float = 5.0,
19531982
mirostat_eta: float = 0.1,
1983+
xtc_threshold: float = 0.1,
1984+
xtc_probability: float = 0.0,
19541985
model: Optional[str] = None,
19551986
logits_processor: Optional[LogitsProcessorList] = None,
19561987
grammar: Optional[LlamaGrammar] = None,
@@ -1985,6 +2016,8 @@ def create_chat_completion(
19852016
mirostat_mode: The mirostat sampling mode.
19862017
mirostat_tau: The mirostat sampling tau parameter.
19872018
mirostat_eta: The mirostat sampling eta parameter.
2019+
xtc_probability: Sets the chance for token removal (checked once on sampler start) (default: 0.0). XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
2020+
xtc_threshold: Sets a minimum probability threshold for tokens to be removed (default: 0.1). XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
19882021
model: The name to use for the model in the completion object.
19892022
logits_processor: A list of logits processors to use.
19902023
grammar: A grammar to use.
@@ -2024,6 +2057,8 @@ def create_chat_completion(
20242057
mirostat_mode=mirostat_mode,
20252058
mirostat_tau=mirostat_tau,
20262059
mirostat_eta=mirostat_eta,
2060+
xtc_threshold=xtc_threshold,
2061+
xtc_probability=xtc_probability,
20272062
model=model,
20282063
logits_processor=logits_processor,
20292064
grammar=grammar,

0 commit comments

Comments
 (0)