Commit 41cff76: Append top_n_sigma params to sampler_init

1 parent: cb49f02
File tree: 1 file changed (+24 -2 lines)


llama_cpp/llama.py

Lines changed: 24 additions & 2 deletions
@@ -669,6 +669,7 @@ def eval(self, tokens: Sequence[int]):
     def _init_sampler(
         self,
         top_k: int = 40,
+        top_n_sigma: float = -1.00,
         top_p: float = 0.95,
         min_p: float = 0.05,
         typical_p: float = 1.0,
@@ -751,6 +752,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
             min_keep = max(1, n_probs)
             sampler.add_top_k(top_k)
             sampler.add_typical(typical_p, min_keep)
+            sampler.add_top_n_sigma(top_n_sigma)
             sampler.add_top_p(top_p, min_keep)
             sampler.add_min_p(min_p, min_keep)
             sampler.add_temp(temp)
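
For orientation, here is a minimal NumPy sketch of the filtering step that the new sampler.add_top_n_sigma(top_n_sigma) call represents: tokens whose pre-softmax logit falls more than n standard deviations below the maximum logit are masked out. This is an illustration of the technique only, not the llama.cpp implementation, and the function name is ours:

import numpy as np

def top_n_sigma_filter(logits: np.ndarray, n_sigma: float) -> np.ndarray:
    # A negative n_sigma disables the filter, matching the -1.00 default above.
    if n_sigma < 0:
        return logits
    # Keep only tokens within n_sigma standard deviations of the max logit.
    threshold = logits.max() - n_sigma * logits.std()
    return np.where(logits >= threshold, logits, -np.inf)

logits = np.array([5.0, 4.8, 2.0, -1.0])
print(top_n_sigma_filter(logits, 1.0))  # [5.0, 4.8, -inf, -inf]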
@@ -761,6 +763,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
     def sample(
         self,
         top_k: int = 40,
+        top_n_sigma: float = -1.00,
         top_p: float = 0.95,
         min_p: float = 0.05,
         typical_p: float = 1.0,
@@ -798,6 +801,7 @@ def sample(
             tmp_sampler = True
             self._sampler = self._init_sampler(
                 top_k=top_k,
+                top_n_sigma=top_n_sigma,
                 top_p=top_p,
                 min_p=min_p,
                 typical_p=typical_p,
@@ -828,6 +832,7 @@ def generate(
         self,
         tokens: Sequence[int],
         top_k: int = 40,
+        top_n_sigma: float = -1.00,
         top_p: float = 0.95,
         min_p: float = 0.05,
         typical_p: float = 1.0,
@@ -870,6 +875,7 @@ def generate(
             self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
         self._sampler = self._init_sampler(
             top_k=top_k,
+            top_n_sigma=top_n_sigma,
             top_p=top_p,
             min_p=min_p,
             typical_p=typical_p,
@@ -924,6 +930,7 @@ def generate(
         while sample_idx < self.n_tokens:
             token = self.sample(
                 top_k=top_k,
+                top_n_sigma=top_n_sigma,
                 top_p=top_p,
                 min_p=min_p,
                 typical_p=typical_p,
@@ -1147,6 +1154,7 @@ def _create_completion(
         presence_penalty: float = 0.0,
         repeat_penalty: float = 1.0,
         top_k: int = 40,
+        top_n_sigma: float = -1.00,
         stream: bool = False,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
@@ -1335,6 +1343,7 @@ def logit_bias_processor(
         for token in self.generate(
             prompt_tokens,
             top_k=top_k,
+            top_n_sigma=top_n_sigma,
             top_p=top_p,
             min_p=min_p,
             typical_p=typical_p,
@@ -1771,6 +1780,7 @@ def create_completion(
         presence_penalty: float = 0.0,
         repeat_penalty: float = 1.0,
         top_k: int = 40,
+        top_n_sigma: float = -1.00,
         stream: bool = False,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
@@ -1802,14 +1812,15 @@ def create_completion(
             presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
             repeat_penalty: The penalty to apply to repeated tokens.
             top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+            top_n_sigma: Limit the next-token selection to the subset of tokens whose pre-softmax logits are within n * σ of the maximum logit (default: -1.00; -1.00 = disabled).
             stream: Whether to stream the results.
             seed: The seed to use for sampling.
             tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
             mirostat_mode: The mirostat sampling mode.
             mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
             mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
-            xtc_probability: Sets the chance for token removal (checked once on sampler start) (default: 0.0).
-            xtc_threshold: Sets a minimum probability threshold for tokens to be removed (default: 0.1).
+            xtc_probability: Sets the chance for token removal (checked once on sampler start) (default: 0.0). XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
+            xtc_threshold: Sets a minimum probability threshold for tokens to be removed (default: 0.1). XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
             model: The name to use for the model in the completion object.
             stopping_criteria: A list of stopping criteria to use.
             logits_processor: A list of logits processors to use.
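
Since the docstring above only gestures at XTC, a short sketch of the behavior described in the linked PR may help: with probability xtc_probability, every token above xtc_threshold except the least likely of them is removed. This is a hypothetical reimplementation for illustration, not the sampler code touched by this commit:

import random

def xtc_filter(probs: dict[str, float], xtc_threshold: float, xtc_probability: float) -> dict[str, float]:
    # With chance xtc_probability, exclude every token whose probability
    # reaches xtc_threshold, except the least likely of those candidates.
    if random.random() >= xtc_probability:
        return probs
    top_choices = [tok for tok, p in probs.items() if p >= xtc_threshold]
    if len(top_choices) < 2:
        return probs  # fewer than two candidates above the threshold: nothing to exclude
    keep = min(top_choices, key=lambda tok: probs[tok])
    return {tok: p for tok, p in probs.items() if p < xtc_threshold or tok == keep}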
@@ -1838,6 +1849,7 @@ def create_completion(
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
             top_k=top_k,
+            top_n_sigma=top_n_sigma,
             stream=stream,
             seed=seed,
             tfs_z=tfs_z,
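
A hypothetical call exercising the new parameter end to end through the public API (the model path and prompt are placeholders):

from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf")
out = llm.create_completion(
    "The capital of France is",
    max_tokens=16,
    top_n_sigma=1.0,  # keep only tokens within one sigma of the max logit
)
print(out["choices"][0]["text"])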
@@ -1874,6 +1886,7 @@ def __call__(
         presence_penalty: float = 0.0,
         repeat_penalty: float = 1.0,
         top_k: int = 40,
+        top_n_sigma: float = -1.00,
         stream: bool = False,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
@@ -1905,6 +1918,7 @@ def __call__(
             presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
             repeat_penalty: The penalty to apply to repeated tokens.
             top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+            top_n_sigma: Limit the next-token selection to the subset of tokens whose pre-softmax logits are within n * σ of the maximum logit (default: -1.00; -1.00 = disabled).
             stream: Whether to stream the results.
             seed: The seed to use for sampling.
             tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
@@ -1941,6 +1955,7 @@ def __call__(
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
             top_k=top_k,
+            top_n_sigma=top_n_sigma,
             stream=stream,
             seed=seed,
             tfs_z=tfs_z,
@@ -1966,6 +1981,7 @@ def create_chat_completion(
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        top_n_sigma: float = -1.00,
         min_p: float = 0.05,
         typical_p: float = 1.0,
         stream: bool = False,
@@ -2002,6 +2018,7 @@ def create_chat_completion(
             temperature: The temperature to use for sampling.
             top_p: The top-p value to use for nucleus sampling. Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
             top_k: The top-k value to use for sampling. Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+            top_n_sigma: Limit the next-token selection to the subset of tokens whose pre-softmax logits are within n * σ of the maximum logit (default: -1.00; -1.00 = disabled).
             min_p: The min-p value to use for minimum p sampling. Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841
             typical_p: The typical-p value to use for sampling. Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
             stream: Whether to stream the results.
@@ -2041,6 +2058,7 @@ def create_chat_completion(
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            top_n_sigma=top_n_sigma,
             min_p=min_p,
             typical_p=typical_p,
             logprobs=logprobs,
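
The parameter now also flows through the chat API; a brief sketch with a placeholder model path:

from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf")
resp = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Name three primary colors."}],
    top_n_sigma=1.0,  # forwarded to the underlying completion sampler
)
print(resp["choices"][0]["message"]["content"])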
@@ -2208,6 +2226,10 @@ def n_embd(self) -> int:
         """Return the embedding size."""
         return self._model.n_embd()
 
+    def n_head_kv(self) -> int:
+        """Return the number of key/value heads."""
+        return self._model.n_head_kv()
+
     def n_vocab(self) -> int:
         """Return the vocabulary size."""
         return self._model.n_vocab()
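
A quick usage sketch for the new accessor (the model path is a placeholder):

from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf", verbose=False)
print(llm.n_head_kv())  # number of key/value heads reported by the loaded model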
