Commit 4decc4d

Add gen options and CoT removal (#587)
* add gen options and CoT removal
* comment
1 parent 0d82724

1 file changed
mlx_lm/evaluate.py

Lines changed: 24 additions & 5 deletions
@@ -12,7 +12,7 @@
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 import lm_eval
 import mlx.core as mx
@@ -25,6 +25,7 @@
 
 from .generate import batch_generate
 from .models.cache import make_prompt_cache
+from .sample_utils import make_sampler
 from .utils import common_prefix_len, load
 
 DEFAULT_MAX_TOKENS = 8192
@@ -38,6 +39,13 @@ def _rstrip_until(s, untils):
     return s[: min(f)]
 
 
+def _lstrip(s, pattern):
+    """Truncate the prefix of the string after the first occurrence of pattern."""
+    if (idx := s.find(pattern)) != -1:
+        return s[idx + len(pattern) :]
+    return s
+
+
 def _pad_inputs(inputs):
     lengths = np.array([len(x) for x in inputs])
     maxlen = lengths.max()
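
The new _lstrip helper removes everything up to and including the first occurrence of a marker, leaving the string unchanged when the marker is absent; generate_until uses it below to drop a model's chain-of-thought prefix. A quick illustration, using "</think>" as a stand-in for the tokenizer's actual think-end marker:

    text = "<think>reason step by step...</think>The answer is 42."
    _lstrip(text, "</think>")  # -> "The answer is 42."
    _lstrip(text, "<sep>")     # marker absent -> string returned unchanged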
@@ -73,6 +81,7 @@ def __init__(
         max_tokens: Optional[int] = None,
         use_chat_template: Optional[bool] = None,
         trust_remote_code: bool = False,
+        sampler: Optional[Callable[[mx.array], mx.array]] = None,
     ) -> None:
         super().__init__()
         tokenizer_config = {"trust_remote_code": True if trust_remote_code else None}
@@ -84,6 +93,7 @@ def __init__(
         self.use_chat_template = use_chat_template
         if use_chat_template is None:
             self.use_chat_template = self.tokenizer.chat_template is not None
+        self._sampler = sampler
 
     def _process_prompt(self, prompt, step_size: int = 2048):
         prompt = mx.array(prompt)[None]
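
The sampler argument is any callable that maps an mx.array of log-probabilities to an mx.array of sampled token ids. A minimal conforming sketch (greedy decoding, which is what the temp=0.0 default below amounts to):

    import mlx.core as mx

    def greedy_sampler(logprobs: mx.array) -> mx.array:
        # Take the highest-scoring token along the vocabulary axis.
        return mx.argmax(logprobs, axis=-1)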
@@ -338,12 +348,13 @@ def generate_until(self, requests) -> list[str]:
             prompts=contexts,
             max_tokens=max_tokens,
             verbose=True,
+            sampler=self._sampler,
         ).texts
 
         for e, (text, opt) in enumerate(zip(completions, options)):
-            until = opt["until"]
-            if any(u in text for u in until):
-                completions[e] = _rstrip_until(text, until)
+            completions[e] = _rstrip_until(text, opt["until"])
+            if self.tokenizer.has_thinking:
+                completions[e] = _lstrip(text, self.tokenizer.think_end)
 
         # Gather the completions
         if group.size() > 1:
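
In the rewritten loop each helper is applied to the raw completion text: _rstrip_until truncates at the first of the task's stop strings, and _lstrip removes the chain-of-thought prefix when the tokenizer reports a thinking marker. A walk-through with illustrative values for opt["until"] and think_end:

    raw = "<think>reasoning...</think>Paris.\n\nQuestion:"
    _rstrip_until(raw, ["\n\nQuestion:"])  # -> "<think>reasoning...</think>Paris."
    _lstrip(raw, "</think>")               # -> "Paris.\n\nQuestion:"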
@@ -438,7 +449,9 @@ def main():
         action="store_true",
         help="Enable trusting remote code for tokenizer",
     )
-
+    parser.add_argument("--temp", type=float, default=0.0, help="Sampling temperature")
+    parser.add_argument("--top-p", type=float, default=1.0, help="Sampling top-p")
+    parser.add_argument("--top-k", type=int, default=0, help="Sampling top-k")
     args = parser.parse_args()
 
     output_dir = Path(args.output_dir)
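
These flags expose the sampler settings on the command line; the defaults (temp 0.0, top-p 1.0, top-k 0) preserve the previous greedy behavior. An illustrative invocation, assuming the mlx_lm.evaluate entry point and a placeholder model path:

    mlx_lm.evaluate --model <model-path> --temp 0.7 --top-p 0.95 --top-k 50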
@@ -455,11 +468,17 @@ def main():
     if world.size() > 1 and world.rank() == 0:
         print(f"Evaluating with {world.size()} nodes")
 
+    sampler = make_sampler(
+        temp=args.temp,
+        top_p=args.top_p,
+        top_k=args.top_k,
+    )
     lm = MLXLM(
         args.model,
         max_tokens=args.max_tokens,
         use_chat_template=args.apply_chat_template,
         trust_remote_code=args.trust_remote_code,
+        sampler=sampler,
     )
     MLXLM.apply_chat_template = chat_template_fn(**args.chat_template_args)

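Here make_sampler turns the CLI options into the callable handed to MLXLM. Conceptually, a temperature sampler scales the log-probabilities and draws from the resulting distribution; a rough sketch of the idea (not mlx_lm's exact implementation):

    import mlx.core as mx

    def temperature_sample(logprobs: mx.array, temp: float = 0.7) -> mx.array:
        # Dividing by temp sharpens (temp < 1) or flattens (temp > 1) the
        # distribution before drawing a categorical sample.
        return mx.random.categorical(logprobs * (1 / temp))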