Skip to content

Commit f369773

Browse files
authored
fix eval thinking (#578)
1 parent 1e8fca4 commit f369773

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

mlx_lm/evaluate.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,8 @@
2727
from .models.cache import make_prompt_cache
2828
from .utils import common_prefix_len, load
2929

30+
DEFAULT_MAX_TOKENS = 8192
31+
3032

3133
def _rstrip_until(s, untils):
3234
"""Limit a string <s> to the first occurrence of any substring in untils."""
@@ -68,7 +70,7 @@ class MLXLM(LM):
6870
def __init__(
6971
self,
7072
path_or_hf_repo: str,
71-
max_tokens: int,
73+
max_tokens: Optional[int] = None,
7274
use_chat_template: Optional[bool] = None,
7375
trust_remote_code: bool = False,
7476
) -> None:
@@ -182,7 +184,8 @@ def loglikelihood(self, requests) -> list[tuple[float, bool]]:
182184
max_completed_l = max(len(s) for s in full_sequences)
183185

184186
# compute truncation length
185-
truncation = max(0, max_completed_l - self._max_tokens - 1)
187+
max_tokens = self._max_tokens or DEFAULT_MAX_TOKENS
188+
truncation = max(0, max_completed_l - max_tokens - 1)
186189
orig_prefix_l = len(prefix)
187190
prefix_l = max(len(prefix) - truncation, 0)
188191
prefix = prefix[len(prefix) - prefix_l :]
@@ -324,7 +327,10 @@ def generate_until(self, requests) -> list[str]:
324327
]
325328

326329
# TODO consider multi-token, per-prompt stop conditions
327-
max_tokens = [opt.get("max_gen_toks", self._max_tokens) for opt in options]
330+
max_tokens = [
331+
self._max_tokens or opt.get("max_gen_tokens", DEFAULT_MAX_TOKENS)
332+
for opt in options
333+
]
328334

329335
completions = batch_generate(
330336
model=self._model,
@@ -388,8 +394,9 @@ def main():
388394
parser.add_argument(
389395
"--max-tokens",
390396
type=int,
391-
help="Maximum number of tokens to generate.",
392-
default=8912,
397+
help="Maximum number of tokens to generate. When set, this value takes"
398+
" precedence over task specific defaults.",
399+
default=None,
393400
)
394401
parser.add_argument(
395402
"--limit",

0 commit comments

Comments
 (0)