@@ -12,7 +12,7 @@
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Callable, Optional

 import lm_eval
 import mlx.core as mx
@@ -25,6 +25,7 @@

 from .generate import batch_generate
 from .models.cache import make_prompt_cache
+from .sample_utils import make_sampler
 from .utils import common_prefix_len, load

 DEFAULT_MAX_TOKENS = 8192
@@ -38,6 +39,13 @@ def _rstrip_until(s, untils):
     return s[: min(f)]


+def _lstrip(s, pattern):
+    """Remove everything up to and including the first occurrence of pattern."""
+    if (idx := s.find(pattern)) != -1:
+        return s[idx + len(pattern) :]
+    return s
+
+
 def _pad_inputs(inputs):
     lengths = np.array([len(x) for x in inputs])
     maxlen = lengths.max()
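
For context, a quick sketch of how the new _lstrip helper behaves; the "</think>" marker below is purely illustrative (in practice the value comes from tokenizer.think_end):

    text = "reasoning trace...</think>The answer is 42."
    _lstrip(text, "</think>")         # -> "The answer is 42."
    _lstrip("no marker", "</think>")  # -> "no marker" (unchanged when absent)
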
@@ -73,6 +81,7 @@ def __init__(
         max_tokens: Optional[int] = None,
         use_chat_template: Optional[bool] = None,
         trust_remote_code: bool = False,
+        sampler: Optional[Callable[[mx.array], mx.array]] = None,
     ) -> None:
         super().__init__()
         tokenizer_config = {"trust_remote_code": True if trust_remote_code else None}
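
The sampler argument follows the contract in the annotation above: a callable mapping an mx.array of log-probabilities to an mx.array of sampled token ids. A minimal hand-rolled sketch (the function name and placeholder model path are illustrative):

    import mlx.core as mx

    def greedy_sampler(logprobs: mx.array) -> mx.array:
        # Pick the highest-probability token; equivalent to sampling at temp=0.0.
        return mx.argmax(logprobs, axis=-1)

    lm = MLXLM("path/to/model", sampler=greedy_sampler)
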
@@ -84,6 +93,7 @@ def __init__(
         self.use_chat_template = use_chat_template
         if use_chat_template is None:
             self.use_chat_template = self.tokenizer.chat_template is not None
+        self._sampler = sampler

     def _process_prompt(self, prompt, step_size: int = 2048):
         prompt = mx.array(prompt)[None]
@@ -338,12 +348,13 @@ def generate_until(self, requests) -> list[str]:
             prompts=contexts,
             max_tokens=max_tokens,
             verbose=True,
+            sampler=self._sampler,
         ).texts

         for e, (text, opt) in enumerate(zip(completions, options)):
-            until = opt["until"]
-            if any(u in text for u in until):
-                completions[e] = _rstrip_until(text, until)
+            completions[e] = _rstrip_until(text, opt["until"])
+            if self.tokenizer.has_thinking:
+                completions[e] = _lstrip(completions[e], self.tokenizer.think_end)

         # Gather the completions
         if group.size() > 1:
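
Note the two post-processing steps now compose: stop sequences are stripped from the right first, then, for thinking models, everything through think_end is dropped from the left. Roughly, with illustrative markers:

    raw = "chain of thought...</think>Paris\n\nQ:"
    step1 = _rstrip_until(raw, ["\n\nQ:"])  # "chain of thought...</think>Paris"
    step2 = _lstrip(step1, "</think>")      # "Paris"
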
@@ -438,7 +449,9 @@ def main():
         action="store_true",
         help="Enable trusting remote code for tokenizer",
     )
-
+    parser.add_argument("--temp", type=float, default=0.0, help="Sampling temperature")
+    parser.add_argument("--top-p", type=float, default=1.0, help="Sampling top-p")
+    parser.add_argument("--top-k", type=int, default=0, help="Sampling top-k")
     args = parser.parse_args()

     output_dir = Path(args.output_dir)
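
With these flags in place, a sampled eval run might look like the line below; this assumes the mlx_lm.evaluate entry point, and the model and task values are placeholders. Leaving the flags at their defaults keeps decoding greedy, as before.

    # shell: mlx_lm.evaluate --model <path-or-repo> --tasks <task> --temp 0.7 --top-p 0.9
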
@@ -455,11 +468,17 @@ def main():
     if world.size() > 1 and world.rank() == 0:
         print(f"Evaluating with {world.size()} nodes")

+    sampler = make_sampler(
+        temp=args.temp,
+        top_p=args.top_p,
+        top_k=args.top_k,
+    )
     lm = MLXLM(
         args.model,
         max_tokens=args.max_tokens,
         use_chat_template=args.apply_chat_template,
         trust_remote_code=args.trust_remote_code,
+        sampler=sampler,
     )
     MLXLM.apply_chat_template = chat_template_fn(**args.chat_template_args)

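Since the new flags default to temp=0.0, top_p=1.0, top_k=0, the constructed sampler should reduce to greedy decoding, so existing evaluation results are unchanged unless sampling is explicitly requested. A sketch of non-default use (values illustrative):

    from mlx_lm.sample_utils import make_sampler

    sampler = make_sampler(temp=0.7, top_p=0.9, top_k=50)
    # Passing this to MLXLM threads it through to batch_generate, as wired above.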