@@ -329,6 +329,8 @@ def __init__(
             (n_ctx, self._n_vocab), dtype=np.single
         )
 
+        self._mirostat_mu = ctypes.c_float(2.0 * 5.0)  # TODO: Move this to sampling context
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         assert self._ctx.ctx is not None
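Keeping the mirostat mu state on the instance as a ctypes.c_float, instead of recomputing a plain Python float on every call, lets the C sampler update it in place through a pointer, so the value persists across sample() calls. A minimal sketch of that pass-by-pointer pattern, with a hypothetical decay() helper standing in for the real sampler:

import ctypes

# Persistent state that native code can mutate through a pointer.
mu = ctypes.c_float(2.0 * 5.0)

def decay(mu_ptr):
    # Stand-in for a C sampler that adjusts mu in place.
    mu_ptr.contents.value *= 0.9

decay(ctypes.pointer(mu))
print(mu.value)  # 9.0 -- the update survives the call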
@@ -516,7 +518,7 @@ def sample(
                 candidates=self._candidates,
                 tau=mirostat_tau,
                 eta=mirostat_eta,
-                mu=2.0 * mirostat_tau,
+                mu=ctypes.pointer(self._mirostat_mu),
                 m=100,
             )
         elif mirostat_mode == 2:
@@ -525,7 +527,7 @@ def sample(
                 candidates=self._candidates,
                 tau=mirostat_tau,
                 eta=mirostat_eta,
-                mu=2.0 * mirostat_tau,
+                mu=ctypes.pointer(self._mirostat_mu)
             )
         else:
             self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
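Before this change, sample() passed mu=2.0 * mirostat_tau, re-deriving the value on every call, so the controller's feedback never accumulated. Passing ctypes.pointer(self._mirostat_mu) gives the llama.cpp mirostat samplers a mutable float they can adjust after each token. A simplified sketch of that feedback rule, following the mirostat paper (names are illustrative; the actual update happens inside llama.cpp):

import math

def mirostat_mu_update(mu, token_prob, tau, eta):
    # Observed surprise (in bits) of the token that was just sampled.
    surprise = -math.log2(token_prob)
    # Nudge mu so future surprise tracks the target tau.
    return mu - eta * (surprise - tau)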
@@ -581,6 +583,10 @@ def generate(
         Yields:
             The generated tokens.
         """
+        # Reset mirostat sampling
+        self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
+
+        # Check for kv cache prefix match
         if reset and self.n_tokens > 0:
             longest_prefix = 0
             for a, b in zip(self._input_ids, tokens[:-1]):
@@ -595,12 +601,15 @@ def generate(
                 tokens = tokens[longest_prefix:]
                 self.n_tokens = longest_prefix
 
+        # Reset the model state
         if reset:
             self.reset()
 
+        # Reset the grammar
         if grammar is not None:
             grammar.reset()
 
+        # Eval and sample
         while True:
             self.eval(tokens)
             token = self.sample(
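Because _mirostat_mu now lives on the instance, generate() re-seeds it to 2.0 * mirostat_tau at the start of each run, so the state adapted during one generation does not leak into the next. A hedged usage sketch through the high-level API, assuming the usual mirostat_* completion kwargs and a placeholder model path:

from llama_cpp import Llama

llm = Llama(model_path="model.gguf")  # placeholder path

# Each completion runs generate(), which resets mu to 2.0 * mirostat_tau,
# so both prompts start from the same controller state.
for prompt in ("First prompt:", "Second prompt:"):
    out = llm(prompt, max_tokens=32, mirostat_mode=2, mirostat_tau=5.0, mirostat_eta=0.1)
    print(out["choices"][0]["text"])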