@@ -268,14 +268,13 @@ def _sample(
268268 top_k : llama_cpp .c_int ,
269269 top_p : llama_cpp .c_float ,
270270 temp : llama_cpp .c_float ,
271- mirostat_mode : llama_cpp .c_int ,
272- mirostat_tau : llama_cpp .c_float ,
273- mirostat_eta : llama_cpp .c_float ,
274- mirostat_mu : llama_cpp .c_float ,
275- mirostat_m : llama_cpp .c_int ,
271+ tfs_z : llama_cpp .c_float ,
276272 repeat_penalty : llama_cpp .c_float ,
277273 frequency_penalty : llama_cpp .c_float ,
278274 presence_penalty : llama_cpp .c_float ,
275+ mirostat_mode : llama_cpp .c_int ,
276+ mirostat_tau : llama_cpp .c_float ,
277+ mirostat_eta : llama_cpp .c_float ,
279278 ):
280279 assert self .ctx is not None
281280 assert len (self .eval_logits ) > 0
@@ -305,45 +304,48 @@ def _sample(
305304 candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
306305 penalty = repeat_penalty ,
307306 )
308- if mirostat_mode .value == 1 :
307+ llama_cpp .llama_sample_frequency_and_presence_penalties (
308+ ctx = self .ctx ,
309+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
310+ last_tokens_data = last_n_tokens_data ,
311+ last_tokens_size = last_n_tokens_size ,
312+ alpha_frequency = frequency_penalty ,
313+ alpha_presence = presence_penalty ,
314+ )
315+ if temp .value == 0.0 :
316+ return llama_cpp .llama_sample_token_greedy (
317+ ctx = self .ctx ,
318+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
319+ )
320+ elif mirostat_mode .value == 1 :
321+ mirostat_mu = llama_cpp .c_float (2.0 * mirostat_tau .value )
322+ mirostat_m = llama_cpp .c_int (100 )
309323 llama_cpp .llama_sample_temperature (
310324 ctx = self .ctx ,
311- candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
325+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
312326 temp = temp ,
313327 )
314- llama_cpp .llama_sample_token_mirostat (
328+ return llama_cpp .llama_sample_token_mirostat (
315329 ctx = self .ctx ,
316- candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
330+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
317331 tau = mirostat_tau ,
318332 eta = mirostat_eta ,
319- mu = llama_cpp .ctypes .byref (mirostat_mu ), # type: ignore
320- m = mirostat_m
333+ mu = llama_cpp .ctypes .byref (mirostat_mu ), # type: ignore
334+ m = mirostat_m ,
321335 )
322336 elif mirostat_mode .value == 2 :
337+ mirostat_mu = llama_cpp .c_float (2.0 * mirostat_tau .value )
323338 llama_cpp .llama_sample_temperature (
324339 ctx = self .ctx ,
325340 candidates = llama_cpp .ctypes .pointer (candidates ),
326341 temp = temp ,
327342 )
328- llama_cpp .llama_sample_token_mirostat_v2 (
343+ return llama_cpp .llama_sample_token_mirostat_v2 (
329344 ctx = self .ctx ,
330- candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
345+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
331346 tau = mirostat_tau ,
332347 eta = mirostat_eta ,
333- mu = llama_cpp .ctypes .byref (mirostat_mu ) # type: ignore
334- )
335- llama_cpp .llama_sample_frequency_and_presence_penalties (
336- ctx = self .ctx ,
337- candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
338- last_tokens_data = last_n_tokens_data ,
339- last_tokens_size = last_n_tokens_size ,
340- alpha_frequency = frequency_penalty ,
341- alpha_presence = presence_penalty ,
342- )
343- if float (temp .value ) == 0.0 :
344- return llama_cpp .llama_sample_token_greedy (
345- ctx = self .ctx ,
346- candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
348+ mu = llama_cpp .ctypes .byref (mirostat_mu ), # type: ignore
347349 )
348350 else :
349351 llama_cpp .llama_sample_top_k (
@@ -355,7 +357,7 @@ def _sample(
355357 llama_cpp .llama_sample_tail_free (
356358 ctx = self .ctx ,
357359 candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
358- z = llama_cpp . c_float ( 1.0 ) ,
360+ z = tfs_z ,
359361 min_keep = llama_cpp .c_size_t (1 ),
360362 )
361363 llama_cpp .llama_sample_typical (
@@ -382,17 +384,16 @@ def _sample(
382384
383385 def sample (
384386 self ,
385- top_k : int ,
386- top_p : float ,
387- temp : float ,
388- mirostat_mode : int ,
389- mirostat_tau : float ,
390- mirostat_eta : float ,
391- mirostat_mu : float ,
392- mirostat_m : int ,
393- repeat_penalty : float ,
387+ top_k : int = 40 ,
388+ top_p : float = 0.95 ,
389+ temp : float = 0.80 ,
390+ repeat_penalty : float = 1.1 ,
394391 frequency_penalty : float = 0.0 ,
395392 presence_penalty : float = 0.0 ,
393+ tfs_z : float = 1.0 ,
394+ mirostat_mode : int = 0 ,
395+ mirostat_eta : float = 0.1 ,
396+ mirostat_tau : float = 5.0 ,
396397 ):
397398 """Sample a token from the model.
398399
@@ -417,14 +418,13 @@ def sample(
417418 top_k = llama_cpp .c_int (top_k ),
418419 top_p = llama_cpp .c_float (top_p ),
419420 temp = llama_cpp .c_float (temp ),
420- mirostat_mode = llama_cpp .c_int (mirostat_mode ),
421- mirostat_mu = llama_cpp .c_float (mirostat_mu ),
422- mirostat_tau = llama_cpp .c_float (mirostat_tau ),
423- mirostat_eta = llama_cpp .c_float (mirostat_eta ),
424- mirostat_m = llama_cpp .c_int (mirostat_m ),
421+ tfs_z = llama_cpp .c_float (tfs_z ),
425422 repeat_penalty = llama_cpp .c_float (repeat_penalty ),
426423 frequency_penalty = llama_cpp .c_float (frequency_penalty ),
427424 presence_penalty = llama_cpp .c_float (presence_penalty ),
425+ mirostat_mode = llama_cpp .c_int (mirostat_mode ),
426+ mirostat_tau = llama_cpp .c_float (mirostat_tau ),
427+ mirostat_eta = llama_cpp .c_float (mirostat_eta ),
428428 )
429429
430430 def generate (
@@ -433,15 +433,13 @@ def generate(
433433 top_k : int ,
434434 top_p : float ,
435435 temp : float ,
436- mirostat_mode : int ,
437- mirostat_tau : float ,
438- mirostat_eta : float ,
439- mirostat_mu : float ,
440- mirostat_m : int ,
441436 repeat_penalty : float ,
437+ reset : bool = True ,
442438 frequency_penalty : float = 0.0 ,
443439 presence_penalty : float = 0.0 ,
444- reset : bool = True ,
440+ mirostat_mode : int = 0 ,
441+ mirostat_tau : float = 5.0 ,
442+ mirostat_eta : float = 0.1 ,
445443 ) -> Generator [
446444 llama_cpp .llama_token , Optional [Sequence [llama_cpp .llama_token ]], None
447445 ]:
@@ -494,14 +492,12 @@ def generate(
494492 top_k = top_k ,
495493 top_p = top_p ,
496494 temp = temp ,
495+ repeat_penalty = repeat_penalty ,
496+ frequency_penalty = frequency_penalty ,
497+ presence_penalty = presence_penalty ,
497498 mirostat_mode = mirostat_mode ,
498499 mirostat_tau = mirostat_tau ,
499500 mirostat_eta = mirostat_eta ,
500- mirostat_mu = mirostat_mu ,
501- mirostat_m = mirostat_m ,
502- frequency_penalty = frequency_penalty ,
503- presence_penalty = presence_penalty ,
504- repeat_penalty = repeat_penalty ,
505501 )
506502 tokens_or_none = yield token
507503 tokens = [token ]
@@ -571,11 +567,6 @@ def _create_completion(
571567 suffix : Optional [str ] = None ,
572568 max_tokens : int = 16 ,
573569 temperature : float = 0.8 ,
574- mirostat_mode : int = 0 ,
575- mirostat_tau : float = 5.0 ,
576- mirostat_eta : float = 0.1 ,
577- mirostat_mu : float = 10 ,
578- mirostat_m : int = 100 ,
579570 top_p : float = 0.95 ,
580571 logprobs : Optional [int ] = None ,
581572 echo : bool = False ,
@@ -585,6 +576,9 @@ def _create_completion(
585576 repeat_penalty : float = 1.1 ,
586577 top_k : int = 40 ,
587578 stream : bool = False ,
579+ mirostat_mode : int = 0 ,
580+ mirostat_tau : float = 5.0 ,
581+ mirostat_eta : float = 0.1 ,
588582 ) -> Union [Iterator [Completion ], Iterator [CompletionChunk ]]:
589583 assert self .ctx is not None
590584 completion_id : str = f"cmpl-{ str (uuid .uuid4 ())} "
@@ -643,8 +637,6 @@ def _create_completion(
643637 mirostat_mode = mirostat_mode ,
644638 mirostat_tau = mirostat_tau ,
645639 mirostat_eta = mirostat_eta ,
646- mirostat_mu = mirostat_mu ,
647- mirostat_m = mirostat_m ,
648640 frequency_penalty = frequency_penalty ,
649641 presence_penalty = presence_penalty ,
650642 repeat_penalty = repeat_penalty ,
@@ -817,11 +809,6 @@ def create_completion(
817809 suffix : Optional [str ] = None ,
818810 max_tokens : int = 128 ,
819811 temperature : float = 0.8 ,
820- mirostat_mode : int = 0 ,
821- mirostat_tau : float = 5.0 ,
822- mirostat_eta : float = 0.1 ,
823- mirostat_mu : float = 10 ,
824- mirostat_m : int = 100 ,
825812 top_p : float = 0.95 ,
826813 logprobs : Optional [int ] = None ,
827814 echo : bool = False ,
@@ -831,6 +818,9 @@ def create_completion(
831818 repeat_penalty : float = 1.1 ,
832819 top_k : int = 40 ,
833820 stream : bool = False ,
821+ mirostat_mode : int = 0 ,
822+ mirostat_tau : float = 5.0 ,
823+ mirostat_eta : float = 0.1 ,
834824 ) -> Union [Completion , Iterator [CompletionChunk ]]:
835825 """Generate text from a prompt.
836826
@@ -859,11 +849,6 @@ def create_completion(
859849 suffix = suffix ,
860850 max_tokens = max_tokens ,
861851 temperature = temperature ,
862- mirostat_mode = mirostat_mode ,
863- mirostat_tau = mirostat_tau ,
864- mirostat_eta = mirostat_eta ,
865- mirostat_mu = mirostat_mu ,
866- mirostat_m = mirostat_m ,
867852 top_p = top_p ,
868853 logprobs = logprobs ,
869854 echo = echo ,
@@ -873,6 +858,9 @@ def create_completion(
873858 repeat_penalty = repeat_penalty ,
874859 top_k = top_k ,
875860 stream = stream ,
861+ mirostat_mode = mirostat_mode ,
862+ mirostat_tau = mirostat_tau ,
863+ mirostat_eta = mirostat_eta ,
876864 )
877865 if stream :
878866 chunks : Iterator [CompletionChunk ] = completion_or_chunks
@@ -886,11 +874,6 @@ def __call__(
886874 suffix : Optional [str ] = None ,
887875 max_tokens : int = 128 ,
888876 temperature : float = 0.8 ,
889- mirostat_mode : int = 0 ,
890- mirostat_tau : float = 5.0 ,
891- mirostat_eta : float = 0.1 ,
892- mirostat_mu : float = 10 ,
893- mirostat_m : int = 100 ,
894877 top_p : float = 0.95 ,
895878 logprobs : Optional [int ] = None ,
896879 echo : bool = False ,
@@ -900,6 +883,9 @@ def __call__(
900883 repeat_penalty : float = 1.1 ,
901884 top_k : int = 40 ,
902885 stream : bool = False ,
886+ mirostat_mode : int = 0 ,
887+ mirostat_tau : float = 5.0 ,
888+ mirostat_eta : float = 0.1 ,
903889 ) -> Union [Completion , Iterator [CompletionChunk ]]:
904890 """Generate text from a prompt.
905891
@@ -928,11 +914,6 @@ def __call__(
928914 suffix = suffix ,
929915 max_tokens = max_tokens ,
930916 temperature = temperature ,
931- mirostat_mode = mirostat_mode ,
932- mirostat_tau = mirostat_tau ,
933- mirostat_eta = mirostat_eta ,
934- mirostat_mu = mirostat_mu ,
935- mirostat_m = mirostat_m ,
936917 top_p = top_p ,
937918 logprobs = logprobs ,
938919 echo = echo ,
@@ -942,6 +923,9 @@ def __call__(
942923 repeat_penalty = repeat_penalty ,
943924 top_k = top_k ,
944925 stream = stream ,
926+ mirostat_mode = mirostat_mode ,
927+ mirostat_tau = mirostat_tau ,
928+ mirostat_eta = mirostat_eta ,
945929 )
946930
947931 def _convert_text_completion_to_chat (
@@ -1014,6 +998,9 @@ def create_chat_completion(
1014998 presence_penalty : float = 0.0 ,
1015999 frequency_penalty : float = 0.0 ,
10161000 repeat_penalty : float = 1.1 ,
1001+ mirostat_mode : int = 0 ,
1002+ mirostat_tau : float = 5.0 ,
1003+ mirostat_eta : float = 0.1 ,
10171004 ) -> Union [ChatCompletion , Iterator [ChatCompletionChunk ]]:
10181005 """Generate a chat completion from a list of messages.
10191006
@@ -1048,6 +1035,9 @@ def create_chat_completion(
10481035 repeat_penalty = repeat_penalty ,
10491036 presence_penalty = presence_penalty ,
10501037 frequency_penalty = frequency_penalty ,
1038+ mirostat_mode = mirostat_mode ,
1039+ mirostat_tau = mirostat_tau ,
1040+ mirostat_eta = mirostat_eta ,
10511041 )
10521042 if stream :
10531043 chunks : Iterator [CompletionChunk ] = completion_or_chunks # type: ignore
0 commit comments