@@ -257,6 +257,11 @@ def _sample_top_p_top_k(
257257 top_k : llama_cpp .c_int ,
258258 top_p : llama_cpp .c_float ,
259259 temp : llama_cpp .c_float ,
260+ mirostat_mode : llama_cpp .c_int ,
261+ mirostat_tau : llama_cpp .c_float ,
262+ mirostat_eta : llama_cpp .c_float ,
263+ mirostat_mu : llama_cpp .c_float ,
264+ mirostat_m : llama_cpp .c_int ,
260265 repeat_penalty : llama_cpp .c_float ,
261266 ):
262267 assert self .ctx is not None
@@ -287,7 +292,34 @@ def _sample_top_p_top_k(
287292 candidates = llama_cpp .ctypes .pointer (candidates ),
288293 penalty = repeat_penalty ,
289294 )
290- if float (temp .value ) == 0.0 :
295+ if mirostat_mode == 1 :
296+ llama_cpp .llama_sample_temperature (
297+ ctx = self .ctx ,
298+ candidates = llama_cpp .ctypes .pointer (candidates ),
299+ temp = temp ,
300+ )
301+ llama_cpp .llama_sample_token_mirostat (
302+ ctx = self .ctx ,
303+ candidates = llama_cpp .ctypes .pointer (candidates ),
304+ tau = mirostat_tau ,
305+ eta = mirostat_eta ,
306+ mu = mirostat_mu ,
307+ m = mirostat_m
308+ )
309+ elif mirostat_mode == 2 :
310+ llama_cpp .llama_sample_temperature (
311+ ctx = self .ctx ,
312+ candidates = llama_cpp .ctypes .pointer (candidates ),
313+ temp = temp ,
314+ )
315+ llama_cpp .llama_sample_token_mirostat_v2 (
316+ ctx = self .ctx ,
317+ candidates = llama_cpp .ctypes .pointer (candidates ),
318+ tau = mirostat_tau ,
319+ eta = mirostat_eta ,
320+ mu = mirostat_mu
321+ )
322+ elif float (temp .value ) == 0.0 :
291323 return llama_cpp .llama_sample_token_greedy (
292324 ctx = self .ctx ,
293325 candidates = llama_cpp .ctypes .pointer (candidates ),
@@ -328,6 +360,11 @@ def sample(
328360 top_k : int ,
329361 top_p : float ,
330362 temp : float ,
363+ mirostat_mode : int ,
364+ mirostat_tau : float ,
365+ mirostat_eta : float ,
366+ mirostat_mu : float ,
367+ mirostat_m : int ,
331368 repeat_penalty : float ,
332369 ):
333370 """Sample a token from the model.
@@ -353,6 +390,11 @@ def sample(
353390 top_k = llama_cpp .c_int (top_k ),
354391 top_p = llama_cpp .c_float (top_p ),
355392 temp = llama_cpp .c_float (temp ),
393+ mirostat = llama_cpp .c_int (mirostat_mode ),
394+ mirostat_mu = llama_cpp .c_float (mirostat_mu ),
395+ mirostat_tau = llama_cpp .c_float (mirostat_tau ),
396+ mirostat_eta = llama_cpp .c_float (mirostat_eta ),
397+ mirostat_m = llama_cpp .c_int (mirostat_m ),
356398 repeat_penalty = llama_cpp .c_float (repeat_penalty ),
357399 )
358400
@@ -362,6 +404,11 @@ def generate(
362404 top_k : int ,
363405 top_p : float ,
364406 temp : float ,
407+ mirostat : int ,
408+ mirostat_tau : float ,
409+ mirostat_eta : float ,
410+ mirostat_mu : float ,
411+ mirostat_m : int ,
365412 repeat_penalty : float ,
366413 reset : bool = True ,
367414 ) -> Generator [
@@ -416,6 +463,11 @@ def generate(
416463 top_k = top_k ,
417464 top_p = top_p ,
418465 temp = temp ,
466+ mirostat_mode = mirostat_mode ,
467+ mirostat_tau = mirostat_tau ,
468+ mirostat_eta = mirostat_eta ,
469+ mirostat_mu = mirostat_mu ,
470+ mirostat_m = mirostat_m ,
419471 repeat_penalty = repeat_penalty ,
420472 )
421473 tokens_or_none = yield token
@@ -486,6 +538,11 @@ def _create_completion(
486538 suffix : Optional [str ] = None ,
487539 max_tokens : int = 16 ,
488540 temperature : float = 0.8 ,
541+ mirostat_mode : int = 0 ,
542+ mirostat_tau : float = 5.0 ,
543+ mirostat_eta : float = 0.1 ,
544+ mirostat_mu : float = 10 ,
545+ mirostat_m : int = 100 ,
489546 top_p : float = 0.95 ,
490547 logprobs : Optional [int ] = None ,
491548 echo : bool = False ,
@@ -536,6 +593,11 @@ def _create_completion(
536593 top_k = top_k ,
537594 top_p = top_p ,
538595 temp = temperature ,
596+ mirostat_mode = mirostat_mode ,
597+ mirostat_tau = mirostat_tau ,
598+ mirostat_eta = mirostat_eta ,
599+ mirostat_mu = mirostat_mu ,
600+ mirostat_m = mirostat_m ,
539601 repeat_penalty = repeat_penalty ,
540602 ):
541603 if token == llama_cpp .llama_token_eos ():
@@ -707,6 +769,11 @@ def create_completion(
707769 suffix : Optional [str ] = None ,
708770 max_tokens : int = 128 ,
709771 temperature : float = 0.8 ,
772+ mirostat_mode : int = 0 ,
773+ mirostat_tau : float = 5.0 ,
774+ mirostat_eta : float = 0.1 ,
775+ mirostat_mu : float = 10 ,
776+ mirostat_m : int = 100 ,
710777 top_p : float = 0.95 ,
711778 logprobs : Optional [int ] = None ,
712779 echo : bool = False ,
@@ -742,6 +809,11 @@ def create_completion(
742809 suffix = suffix ,
743810 max_tokens = max_tokens ,
744811 temperature = temperature ,
812+ mirostat_mode = mirostat_mode ,
813+ mirostat_tau = mirostat_tau ,
814+ mirostat_eta = mirostat_eta ,
815+ mirostat_mu = mirostat_mu ,
816+ mirostat_m = mirostat_m ,
745817 top_p = top_p ,
746818 logprobs = logprobs ,
747819 echo = echo ,
@@ -762,6 +834,11 @@ def __call__(
762834 suffix : Optional [str ] = None ,
763835 max_tokens : int = 128 ,
764836 temperature : float = 0.8 ,
837+ mirostat_mode : int = 0 ,
838+ mirostat_tau : float = 5.0 ,
839+ mirostat_eta : float = 0.1 ,
840+ mirostat_mu : float = 10 ,
841+ mirostat_m : int = 100 ,
765842 top_p : float = 0.95 ,
766843 logprobs : Optional [int ] = None ,
767844 echo : bool = False ,
@@ -797,6 +874,11 @@ def __call__(
797874 suffix = suffix ,
798875 max_tokens = max_tokens ,
799876 temperature = temperature ,
877+ mirostat_mode = mirostat_mode ,
878+ mirostat_tau = mirostat_tau ,
879+ mirostat_eta = mirostat_eta ,
880+ mirostat_mu = mirostat_mu ,
881+ mirostat_m = mirostat_m ,
800882 top_p = top_p ,
801883 logprobs = logprobs ,
802884 echo = echo ,
0 commit comments