|
27 | 27 | from .models.cache import make_prompt_cache |
28 | 28 | from .utils import common_prefix_len, load |
29 | 29 |
|
| 30 | +DEFAULT_MAX_TOKENS = 8192 |
| 31 | + |
30 | 32 |
|
31 | 33 | def _rstrip_until(s, untils): |
32 | 34 | """Limit a string <s> to the first occurrence of any substring in untils.""" |
@@ -68,7 +70,7 @@ class MLXLM(LM): |
68 | 70 | def __init__( |
69 | 71 | self, |
70 | 72 | path_or_hf_repo: str, |
71 | | - max_tokens: int, |
| 73 | + max_tokens: Optional[int] = None, |
72 | 74 | use_chat_template: Optional[bool] = None, |
73 | 75 | trust_remote_code: bool = False, |
74 | 76 | ) -> None: |
@@ -182,7 +184,8 @@ def loglikelihood(self, requests) -> list[tuple[float, bool]]: |
182 | 184 | max_completed_l = max(len(s) for s in full_sequences) |
183 | 185 |
|
184 | 186 | # compute truncation length |
185 | | - truncation = max(0, max_completed_l - self._max_tokens - 1) |
| 187 | + max_tokens = self._max_tokens or DEFAULT_MAX_TOKENS |
| 188 | + truncation = max(0, max_completed_l - max_tokens - 1) |
186 | 189 | orig_prefix_l = len(prefix) |
187 | 190 | prefix_l = max(len(prefix) - truncation, 0) |
188 | 191 | prefix = prefix[len(prefix) - prefix_l :] |
@@ -324,7 +327,10 @@ def generate_until(self, requests) -> list[str]: |
324 | 327 | ] |
325 | 328 |
|
326 | 329 | # TODO consider multi-token, per-prompt stop conditions |
327 | | - max_tokens = [opt.get("max_gen_toks", self._max_tokens) for opt in options] |
| 330 | + max_tokens = [ |
| 331 | + self._max_tokens or opt.get("max_gen_toks", DEFAULT_MAX_TOKENS) |
| 332 | + for opt in options |
| 333 | + ] |
328 | 334 |
|
329 | 335 | completions = batch_generate( |
330 | 336 | model=self._model, |
@@ -388,8 +394,9 @@ def main(): |
388 | 394 | parser.add_argument( |
389 | 395 | "--max-tokens", |
390 | 396 | type=int, |
391 | | - help="Maximum number of tokens to generate.", |
392 | | - default=8912, |
| 397 | + help="Maximum number of tokens to generate. When set, this value takes" |
| 398 | + " precedence over task specific defaults.", |
| 399 | + default=None, |
393 | 400 | ) |
394 | 401 | parser.add_argument( |
395 | 402 | "--limit", |
|
0 commit comments