Skip to content

Commit de6c684

Browse files
authored
Merge pull request #1982 from rhatdan/max
Add unified --max-tokens CLI argument for output token limiting
2 parents 902df29 + eb6acd2 commit de6c684

File tree

10 files changed

+681
-22
lines changed

10 files changed

+681
-22
lines changed

docs/ramalama-perplexity.1.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,13 @@ Accelerated images:
8484
pass --group-add keep-groups to podman (default: False)
8585
If the GPU device on the host system is accessible to the user via group access, this option leaks the groups into the container.
8686

87+
#### **--max-tokens**=*integer*
88+
Maximum number of tokens to generate. Set to 0 for unlimited output (default: 0).
89+
This parameter is mapped to the appropriate runtime-specific parameter:
90+
- llama.cpp: `-n` parameter
91+
- MLX: `--max-tokens` parameter
92+
- vLLM: `--max-tokens` parameter
93+
8794
#### **--name**, **-n**
8895
name of the container to run the Model in
8996

docs/ramalama-run.1.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,13 @@ If GPU device on host system is accessible to user via group access, this option
9898
#### **--keepalive**
9999
duration to keep a model loaded (e.g. 5m)
100100

101+
#### **--max-tokens**=*integer*
102+
Maximum number of tokens to generate. Set to 0 for unlimited output (default: 0).
103+
This parameter is mapped to the appropriate runtime-specific parameter:
104+
- llama.cpp: `-n` parameter
105+
- MLX: `--max-tokens` parameter
106+
- vLLM: `--max-tokens` parameter
107+
101108
#### **--mcp**=SERVER_URL
102109
MCP (Model Context Protocol) servers to use for enhanced tool calling capabilities.
103110
Can be specified multiple times to connect to multiple MCP servers.

docs/ramalama-serve.1.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,13 @@ Accelerated images:
142142
pass --group-add keep-groups to podman (default: False)
143143
If the GPU device on the host system is accessible to the user via group access, this option leaks the groups into the container.
144144

145+
#### **--max-tokens**=*integer*
146+
Maximum number of tokens to generate. Set to 0 for unlimited output (default: 0).
147+
This parameter is mapped to the appropriate runtime-specific parameter:
148+
- llama.cpp: `-n` parameter
149+
- MLX: `--max-tokens` parameter
150+
- vLLM: `--max-tokens` parameter
151+
145152
#### **--model-draft**
146153

147154
A draft model is a smaller, faster model that helps accelerate the decoding

docs/ramalama.conf.5.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ specified vllm model runtime.
124124
Pass `--group-add keep-groups` to podman, when using podman.
125125
In some cases this is needed to access the GPU from a rootless container.
126126

127+
**max_tokens**=0
128+
129+
Maximum number of tokens to generate. Set to 0 for unlimited output (default: 0).
130+
This parameter is mapped to the appropriate runtime-specific parameter when executing models.
131+
127132
**ngl**=-1
128133

129134
number of gpu layers, 0 means CPU inferencing, 999 means use max layers (default: -1)

ramalama/cli.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,15 @@ def runtime_options(parser, command):
879879
help="name of container in which the Model will be run",
880880
completer=suppressCompleter,
881881
)
882+
if command in ["run", "perplexity", "serve"]:
883+
parser.add_argument(
884+
"--max-tokens",
885+
dest="max_tokens",
886+
type=int,
887+
default=CONFIG.max_tokens,
888+
help="maximum number of tokens to generate (0 = unlimited)",
889+
completer=suppressCompleter,
890+
)
882891
add_network_argument(parser, dflt=None)
883892
parser.add_argument(
884893
"--ngl",

ramalama/command/context.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,42 @@
1010
class RamalamaArgsContext:
1111

1212
def __init__(self):
13-
self.host: Optional[str] = None
14-
self.port: Optional[int] = None
15-
self.thinking: Optional[bool] = None
16-
self.ctx_size: Optional[int] = None
1713
self.cache_reuse: Optional[int] = None
18-
self.temp: Optional[float] = None
14+
self.container: Optional[bool] = None
15+
self.ctx_size: Optional[int] = None
1916
self.debug: Optional[bool] = None
20-
self.webui: Optional[bool] = None
21-
self.ngl: Optional[int] = None
22-
self.threads: Optional[int] = None
17+
self.host: Optional[str] = None
2318
self.logfile: Optional[str] = None
24-
self.container: Optional[bool] = None
19+
self.max_tokens: Optional[int] = None
2520
self.model_draft: Optional[str] = None
26-
self.seed: Optional[int] = None
21+
self.ngl: Optional[int] = None
22+
self.port: Optional[int] = None
2723
self.runtime_args: Optional[str] = None
24+
self.seed: Optional[int] = None
25+
self.temp: Optional[float] = None
26+
self.thinking: Optional[bool] = None
27+
self.threads: Optional[int] = None
28+
self.webui: Optional[bool] = None
2829

2930
@staticmethod
3031
def from_argparse(args: argparse.Namespace) -> "RamalamaArgsContext":
3132
ctx = RamalamaArgsContext()
32-
ctx.host = getattr(args, "host", None)
33-
ctx.port = getattr(args, "port", None)
34-
ctx.thinking = getattr(args, "thinking", None)
33+
ctx.cache_reuse = getattr(args, "cache_reuse", None)
34+
ctx.container = getattr(args, "container", None)
3535
ctx.ctx_size = getattr(args, "context", None)
36-
ctx.temp = getattr(args, "temp", None)
3736
ctx.debug = getattr(args, "debug", None)
38-
ctx.webui = getattr(args, "webui", None)
39-
ctx.ngl = getattr(args, "ngl", None)
40-
ctx.threads = getattr(args, "threads", None)
37+
ctx.host = getattr(args, "host", None)
4138
ctx.logfile = getattr(args, "logfile", None)
42-
ctx.container = getattr(args, "container", None)
39+
ctx.max_tokens = getattr(args, "max_tokens", None)
4340
ctx.model_draft = getattr(args, "model_draft", None)
44-
ctx.seed = getattr(args, "seed", None)
41+
ctx.ngl = getattr(args, "ngl", None)
42+
ctx.port = getattr(args, "port", None)
4543
ctx.runtime_args = getattr(args, "runtime_args", None)
46-
ctx.cache_reuse = getattr(args, "cache_reuse", None)
44+
ctx.seed = getattr(args, "seed", None)
45+
ctx.temp = getattr(args, "temp", None)
46+
ctx.thinking = getattr(args, "thinking", None)
47+
ctx.threads = getattr(args, "threads", None)
48+
ctx.webui = getattr(args, "webui", None)
4749
return ctx
4850

4951

ramalama/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,10 @@ class RamalamaSettings:
116116
class BaseConfig:
117117
api: str = "none"
118118
api_key: str = None
119+
cache_reuse: int = 256
119120
carimage: str = "registry.access.redhat.com/ubi10-micro:latest"
120121
container: bool = None # type: ignore
121122
ctx_size: int = 0
122-
cache_reuse: int = 256
123123
default_image: str = DEFAULT_IMAGE
124124
dryrun: bool = False
125125
engine: SUPPORTED_ENGINES | None = field(default_factory=get_default_engine)
@@ -139,6 +139,7 @@ class BaseConfig:
139139
}
140140
)
141141
keep_groups: bool = False
142+
max_tokens: int = 0
142143
ngl: int = -1
143144
ocr: bool = False
144145
port: str = str(DEFAULT_PORT)

ramalama/daemon/service/command_factory.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def _set_defaults(self):
3838
if "temp" not in self.request_args:
3939
self.request_args["temp"] = CONFIG.temp
4040

41+
if "max_tokens" not in self.request_args:
42+
self.request_args["max_tokens"] = CONFIG.max_tokens
43+
4144
if "ngl" not in self.request_args:
4245
self.request_args["ngl"] = CONFIG.ngl
4346

@@ -104,7 +107,7 @@ def _build_llama_serve_command(self) -> list[str]:
104107
if self.request_args.get("webui") == "off":
105108
cmd.extend(["--no-webui"])
106109

107-
if check_nvidia() or check_metal(SimpleNamespace({"container": False})):
110+
if check_nvidia() or check_metal(SimpleNamespace(container=False)):
108111
cmd.extend(["--flash-attn", "on"])
109112

110113
# gpu arguments
@@ -115,4 +118,9 @@ def _build_llama_serve_command(self) -> list[str]:
115118
threads = self.request_args.get("threads")
116119
cmd.extend(["--threads", str(threads)])
117120

121+
# Add max tokens parameter for llama.cpp
122+
max_tokens = self.request_args.get("max_tokens", 0)
123+
if max_tokens > 0:
124+
cmd.extend(["-n", str(max_tokens)])
125+
118126
return cmd

test/e2e/test_cli_max_tokens.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import sys
2+
from contextlib import redirect_stderr, redirect_stdout
3+
from subprocess import CalledProcessError
4+
5+
import pytest
6+
7+
8+
def run_ramalama_direct(args):
    """Run the ramalama CLI in-process and return what it wrote to stdout.

    The CLI entry point is imported and called directly (instead of spawning
    a subprocess) to avoid depending on an installed ``ramalama`` executable.

    Args:
        args: Command-line arguments to pass, without the program name.

    Returns:
        The text written to stdout while ``main()`` ran.

    Raises:
        CalledProcessError: If ``main()`` exits with a non-zero status. The
            ``output`` attribute holds the captured stdout followed by the
            captured stderr.
    """
    import io

    from ramalama.cli import main

    stdout_capture = io.StringIO()
    stderr_capture = io.StringIO()

    # Save the original sys.argv so it can be restored no matter what happens.
    original_argv = sys.argv[:]
    sys.argv = ["ramalama"] + args

    try:
        with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
            try:
                main()
            except SystemExit as e:
                # argparse (e.g. for --help or parse errors) calls sys.exit().
                code = e.code

                # sys.exit() with no argument yields code None, which means
                # success just like 0 does.
                if code is None or code == 0:
                    return stdout_capture.getvalue()

                output = stdout_capture.getvalue() + stderr_capture.getvalue()

                # sys.exit("message") carries a string instead of a status;
                # keep the message and report a generic failure status so the
                # CalledProcessError stays well-formed.
                if not isinstance(code, int):
                    output += str(code)
                    code = 1

                raise CalledProcessError(code, args, output) from e

        # main() returned normally without exiting.
        return stdout_capture.getvalue()
    finally:
        # Always restore the original sys.argv.
        sys.argv = original_argv
43+
44+
45+
@pytest.mark.e2e
def test_max_tokens_cli_argument_help():
    """Verify that every supported subcommand advertises --max-tokens in its help."""

    # Subcommands that are expected to expose the option.
    for subcommand in ("run", "serve", "perplexity"):
        help_text = run_ramalama_direct([subcommand, "--help"])

        option_listed = "--max-tokens" in help_text
        assert option_listed, f"--max-tokens should appear in {subcommand} help"

        description_listed = "maximum number of tokens to generate" in help_text
        assert description_listed, f"Help text should be present in {subcommand}"
56+
57+
58+
@pytest.mark.e2e
def test_max_tokens_argument_parsing():
    """Test that --max-tokens is recognized by the ``run`` argument parser.

    The option is passed together with ``--help`` so that no model is
    actually executed; only argument parsing is exercised.
    """
    try:
        result = run_ramalama_direct(["run", "--max-tokens", "512", "--help"])
    except CalledProcessError as e:
        # str(e) only contains the command and the exit status; the captured
        # stdout/stderr text is stored in e.output, so search that instead.
        output = e.output or ""
        assert "unrecognized arguments: --max-tokens" not in output, f"Argument parsing failed: {output}"
    else:
        # Reaching this point means the argument was parsed successfully.
        assert "--max-tokens" in result
71+
72+
73+
@pytest.mark.e2e
def test_max_tokens_valid_values():
    """Test that --max-tokens accepts a range of valid integer values.

    Each value is passed together with ``--help`` so that only argument
    parsing is exercised.
    """
    valid_values = ["0", "100", "1024", "4096"]

    for value in valid_values:
        try:
            result = run_ramalama_direct(["run", "--max-tokens", value, "--help"])
        except CalledProcessError as e:
            # The parser's error text is in e.output; str(e) would only show
            # the command and exit status and could never contain it.
            output = e.output or ""
            assert "unrecognized arguments" not in output, f"Should accept valid value {value}"
        else:
            # No parsing error: the help output must list the option.
            assert "--max-tokens" in result
87+
88+
89+
@pytest.mark.e2e
def test_max_tokens_default_value():
    """Check that the ``run`` help output contains an entry for --max-tokens.

    Only the presence of the option (or its description) in the help text is
    asserted; the default value itself is not inspected.
    """
    help_text = run_ramalama_direct(["run", "--help"])

    # Collect every help line that refers to the option or its description.
    matching = []
    for help_line in help_text.split('\n'):
        if '--max-tokens' in help_line or 'maximum number of tokens' in help_line:
            matching.append(help_line)

    # At least one such line must exist.
    assert matching, "Should have help text for --max-tokens"
102+
103+
104+
@pytest.mark.e2e
def test_max_tokens_invalid_value():
    """Test that a non-integer --max-tokens value is reported as a type error.

    If the CLI exits with an error, the error must come from the integer
    conversion and not from the option being unknown.
    """
    try:
        run_ramalama_direct(["run", "--max-tokens", "invalid", "--help"])
        # If no exception is raised, this is unexpected but tolerated for now.
    except CalledProcessError as e:
        # Inspect the captured output: str(e) merely echoes the argv list
        # (which itself contains the word "invalid") and the exit status, so
        # asserting on it would pass regardless of what the parser reported.
        output = e.output or ""

        # Should fail due to invalid type conversion, not unrecognized argument.
        assert "unrecognized arguments: --max-tokens" not in output
        # argparse should complain about the invalid int conversion.
        assert "invalid" in output or "int" in output.lower()
117+
118+
119+
@pytest.mark.e2e
def test_max_tokens_negative_value():
    """Test that --max-tokens accepts negative values at the parsing stage.

    The option is declared with ``type=int``, so a negative number is a valid
    value for the parser; how it is interpreted later is not checked here.
    """
    try:
        result = run_ramalama_direct(["run", "--max-tokens", "-1", "--help"])
    except CalledProcessError as e:
        # The parser's message is available in e.output only; str(e) holds
        # just the command and exit status.
        output = e.output or ""
        assert "unrecognized arguments: --max-tokens" not in output
    else:
        # No parsing error: the help output must list the option.
        assert "--max-tokens" in result

0 commit comments

Comments
 (0)