Skip to content

Commit 4263fc6

Browse files
authored
Merge pull request #93 from WecoAI/feature/byok
Add --api-key parameter that allows users to bring their own API keys for calling models. This can be used with the run and resume commands. Multiple keys can be provided, e.g.: --api-key openai=<key> gemini=<key>
2 parents 5faf6fb + b50c495 commit 4263fc6

File tree

8 files changed

+373
-34
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ For more advanced examples, including [Triton](/examples/triton/README.md), [CUD
9595
| `--eval-timeout` | Timeout in seconds for each step in evaluation. | No timeout (unlimited) | `--eval-timeout 3600` |
9696
| `--save-logs` | Save execution output from each optimization step to disk. Creates timestamped directories with raw output files and a JSONL index for tracking execution history. | `False` | `--save-logs` |
9797
| `--apply-change` | Automatically apply the best solution to the source file without prompting. | `False` | `--apply-change` |
98+
| `--api-key` | API keys for LLM providers (BYOK). Format: `provider=key`. Can specify multiple providers. | `None` | `--api-key openai=sk-xxx` |
9899

99100
---
100101

@@ -149,6 +150,7 @@ Arguments for `weco resume`:
149150
|----------|-------------|---------|
150151
| `run-id` | The UUID of the run to resume (shown at the start of each run) | `0002e071-1b67-411f-a514-36947f0c4b31` |
151152
| `--apply-change` | Automatically apply the best solution to the source file without prompting | `--apply-change` |
153+
| `--api-key` | (Optional) API keys for LLM providers (BYOK). Format: `provider=key` | `--api-key openai=sk-xxx` |
152154

153155
Notes:
154156
- Works only for interrupted runs (status: `error`, `terminated`, etc.).

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ dependencies = [
1818
"gitingest",
1919
"fastapi",
2020
"slowapi",
21-
"psutil",
21+
"psutil"
2222
]
2323
keywords = ["AI", "Code Optimization", "Code Generation"]
2424
classifiers = [
@@ -34,7 +34,7 @@ weco = "weco.cli:main"
3434
Homepage = "https://github.com/WecoAI/weco-cli"
3535

3636
[project.optional-dependencies]
37-
dev = ["ruff", "build", "setuptools_scm"]
37+
dev = ["ruff", "build", "setuptools_scm", "pytest>=7.0.0"]
3838

3939
[tool.setuptools]
4040
packages = ["weco"]

tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests for weco CLI."""

tests/test_byok.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
"""Tests to verify API keys are correctly passed through the system and sent to the API."""
2+
3+
import pytest
4+
from unittest.mock import patch, MagicMock
5+
from rich.console import Console
6+
7+
from weco.api import start_optimization_run, evaluate_feedback_then_suggest_next_solution
8+
9+
10+
class TestApiKeysInStartOptimizationRun:
    """Test that api_keys are correctly included in start_optimization_run requests."""

    @staticmethod
    def _stub_response():
        """Build a mock HTTP response mimicking a successful /runs/ reply.

        Shared by all tests in this class; the original duplicated this
        setup verbatim in every test.
        """
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "run_id": "test-run-id",
            "solution_id": "test-solution-id",
            "code": "print('hello')",
            "plan": "test plan",
        }
        mock_response.raise_for_status = MagicMock()
        return mock_response

    @pytest.fixture
    def mock_console(self):
        """Create a mock console for testing."""
        return MagicMock(spec=Console)

    @pytest.fixture
    def base_params(self, mock_console):
        """Base parameters for start_optimization_run."""
        return {
            "console": mock_console,
            "source_code": "print('hello')",
            "source_path": "test.py",
            "evaluation_command": "python test.py",
            "metric_name": "accuracy",
            "maximize": True,
            "steps": 10,
            "code_generator_config": {"model": "o4-mini"},
            "evaluator_config": {"model": "o4-mini"},
            "search_policy_config": {"num_drafts": 2},
        }

    @patch("weco.api.requests.post")
    def test_api_keys_included_in_request(self, mock_post, base_params):
        """Test that api_keys are included in the request JSON when provided."""
        mock_post.return_value = self._stub_response()

        api_keys = {"openai": "sk-test-key", "anthropic": "sk-ant-test"}
        start_optimization_run(**base_params, api_keys=api_keys)

        # Verify the request was made with api_keys in the JSON payload
        mock_post.assert_called_once()
        request_json = mock_post.call_args.kwargs["json"]
        assert "api_keys" in request_json
        assert request_json["api_keys"] == {"openai": "sk-test-key", "anthropic": "sk-ant-test"}

    @patch("weco.api.requests.post")
    def test_api_keys_not_included_when_none(self, mock_post, base_params):
        """Test that api_keys field is not included when api_keys is None."""
        mock_post.return_value = self._stub_response()

        start_optimization_run(**base_params, api_keys=None)

        # Verify the request was made without api_keys
        mock_post.assert_called_once()
        request_json = mock_post.call_args.kwargs["json"]
        assert "api_keys" not in request_json

    @patch("weco.api.requests.post")
    def test_api_keys_not_included_when_empty_dict(self, mock_post, base_params):
        """Test that api_keys field is not included when api_keys is an empty dict."""
        mock_post.return_value = self._stub_response()

        # Empty dict is falsy, so api_keys should not be included
        start_optimization_run(**base_params, api_keys={})

        mock_post.assert_called_once()
        request_json = mock_post.call_args.kwargs["json"]
        assert "api_keys" not in request_json
99+
100+
class TestApiKeysInEvaluateFeedbackThenSuggest:
    """Test that api_keys are correctly included in evaluate_feedback_then_suggest_next_solution requests."""

    @staticmethod
    def _stub_response():
        """Build a mock HTTP response mimicking a successful /suggest reply.

        Shared by all tests in this class; the original duplicated this
        setup verbatim in every test.
        """
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "run_id": "test-run-id",
            "solution_id": "new-solution-id",
            "code": "print('improved')",
            "plan": "improvement plan",
            "is_done": False,
        }
        mock_response.raise_for_status = MagicMock()
        return mock_response

    @staticmethod
    def _call_suggest(api_keys):
        """Invoke the function under test with fixed arguments and the given api_keys."""
        return evaluate_feedback_then_suggest_next_solution(
            console=MagicMock(spec=Console),
            run_id="test-run-id",
            step=1,
            execution_output="accuracy: 0.95",
            auth_headers={"Authorization": "Bearer test-token"},
            api_keys=api_keys,
        )

    @patch("weco.api.requests.post")
    def test_api_keys_included_in_suggest_request(self, mock_post):
        """Test that api_keys are included in the suggest request JSON when provided."""
        mock_post.return_value = self._stub_response()

        self._call_suggest(api_keys={"openai": "sk-test-key"})

        mock_post.assert_called_once()
        request_json = mock_post.call_args.kwargs["json"]
        assert "api_keys" in request_json
        assert request_json["api_keys"] == {"openai": "sk-test-key"}

    @patch("weco.api.requests.post")
    def test_api_keys_not_included_in_suggest_when_none(self, mock_post):
        """Test that api_keys field is not included in suggest request when api_keys is None."""
        mock_post.return_value = self._stub_response()

        self._call_suggest(api_keys=None)

        mock_post.assert_called_once()
        request_json = mock_post.call_args.kwargs["json"]
        assert "api_keys" not in request_json

    @patch("weco.api.requests.post")
    def test_api_keys_not_included_in_suggest_when_empty_dict(self, mock_post):
        # Docstring fixed: the original said "when api_keys is None" here,
        # a copy-paste error -- this case exercises the empty dict.
        """Test that api_keys field is not included in suggest request when api_keys is an empty dict."""
        mock_post.return_value = self._stub_response()

        self._call_suggest(api_keys={})

        mock_post.assert_called_once()
        request_json = mock_post.call_args.kwargs["json"]
        assert "api_keys" not in request_json

tests/test_cli.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""Tests for CLI functions, particularly parse_api_keys."""
2+
3+
import pytest
4+
from weco.cli import parse_api_keys
5+
6+
7+
class TestParseApiKeys:
    """Test cases for parse_api_keys function."""

    def test_parse_api_keys_none(self):
        """Test that None input returns empty dict."""
        result = parse_api_keys(None)
        assert result == {}
        assert isinstance(result, dict)

    def test_parse_api_keys_empty_list(self):
        """Test that empty list returns empty dict."""
        result = parse_api_keys([])
        assert result == {}
        assert isinstance(result, dict)

    def test_parse_api_keys_single_key(self):
        """Test parsing a single API key."""
        assert parse_api_keys(["openai=sk-xxx"]) == {"openai": "sk-xxx"}

    def test_parse_api_keys_multiple_keys(self):
        """Test parsing multiple API keys."""
        result = parse_api_keys(["openai=sk-xxx", "anthropic=sk-ant-yyy"])
        assert result == {"openai": "sk-xxx", "anthropic": "sk-ant-yyy"}

    def test_parse_api_keys_whitespace_handling(self):
        """Test that whitespace is stripped from provider and key."""
        result = parse_api_keys([" openai = sk-xxx ", " anthropic = sk-ant-yyy "])
        assert result == {"openai": "sk-xxx", "anthropic": "sk-ant-yyy"}

    def test_parse_api_keys_key_contains_equals(self):
        """Test that keys containing '=' are handled correctly (split on first '=' only)."""
        assert parse_api_keys(["openai=sk-xxx=extra=more"]) == {"openai": "sk-xxx=extra=more"}

    def test_parse_api_keys_no_equals(self):
        """Test that missing '=' raises ValueError."""
        with pytest.raises(ValueError, match="Invalid API key format.*Expected format: 'provider=key'"):
            parse_api_keys(["openai"])

    def test_parse_api_keys_empty_provider(self):
        """Test that empty provider raises ValueError."""
        with pytest.raises(ValueError, match="Provider and key must be non-empty"):
            parse_api_keys(["=sk-xxx"])

    def test_parse_api_keys_empty_key(self):
        """Test that empty key raises ValueError."""
        with pytest.raises(ValueError, match="Provider and key must be non-empty"):
            parse_api_keys(["openai="])

    def test_parse_api_keys_both_empty(self):
        """Test that both empty provider and key raises ValueError."""
        with pytest.raises(ValueError, match="Provider and key must be non-empty"):
            parse_api_keys(["="])

    def test_parse_api_keys_duplicate_provider(self):
        """Test that duplicate providers overwrite previous value."""
        assert parse_api_keys(["openai=sk-xxx", "openai=sk-yyy"]) == {"openai": "sk-yyy"}

    def test_parse_api_keys_mixed_case_provider(self):
        """Test that mixed case providers are normalized correctly (lowercased)."""
        result = parse_api_keys(["OpenAI=sk-xxx", "ANTHROPIC=sk-ant-yyy"])
        assert result == {"openai": "sk-xxx", "anthropic": "sk-ant-yyy"}

weco/api.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -109,31 +109,31 @@ def start_optimization_run(
109109
log_dir: str = ".runs",
110110
auth_headers: dict = {},
111111
timeout: Union[int, Tuple[int, int]] = (10, 3650),
112+
api_keys: Optional[Dict[str, str]] = None,
112113
) -> Optional[Dict[str, Any]]:
113114
"""Start the optimization run."""
114115
with console.status("[bold green]Starting Optimization..."):
115116
try:
116-
response = requests.post(
117-
f"{__base_url__}/runs/",
118-
json={
119-
"source_code": source_code,
120-
"source_path": source_path,
121-
"additional_instructions": additional_instructions,
122-
"objective": {"evaluation_command": evaluation_command, "metric_name": metric_name, "maximize": maximize},
123-
"optimizer": {
124-
"steps": steps,
125-
"code_generator": code_generator_config,
126-
"evaluator": evaluator_config,
127-
"search_policy": search_policy_config,
128-
},
129-
"eval_timeout": eval_timeout,
130-
"save_logs": save_logs,
131-
"log_dir": log_dir,
132-
"metadata": {"client_name": "cli", "client_version": __pkg_version__},
117+
request_json = {
118+
"source_code": source_code,
119+
"source_path": source_path,
120+
"additional_instructions": additional_instructions,
121+
"objective": {"evaluation_command": evaluation_command, "metric_name": metric_name, "maximize": maximize},
122+
"optimizer": {
123+
"steps": steps,
124+
"code_generator": code_generator_config,
125+
"evaluator": evaluator_config,
126+
"search_policy": search_policy_config,
133127
},
134-
headers=auth_headers,
135-
timeout=timeout,
136-
)
128+
"eval_timeout": eval_timeout,
129+
"save_logs": save_logs,
130+
"log_dir": log_dir,
131+
"metadata": {"client_name": "cli", "client_version": __pkg_version__},
132+
}
133+
if api_keys:
134+
request_json["api_keys"] = api_keys
135+
136+
response = requests.post(f"{__base_url__}/runs/", json=request_json, headers=auth_headers, timeout=timeout)
137137
response.raise_for_status()
138138
result = response.json()
139139
# Handle None values for code and plan fields
@@ -156,11 +156,10 @@ def resume_optimization_run(
156156
"""Request the backend to resume an interrupted run."""
157157
with console.status("[bold green]Resuming run..."):
158158
try:
159+
request_json = {"metadata": {"client_name": "cli", "client_version": __pkg_version__}}
160+
159161
response = requests.post(
160-
f"{__base_url__}/runs/{run_id}/resume",
161-
json={"metadata": {"client_name": "cli", "client_version": __pkg_version__}},
162-
headers=auth_headers,
163-
timeout=timeout,
162+
f"{__base_url__}/runs/{run_id}/resume", json=request_json, headers=auth_headers, timeout=timeout
164163
)
165164
response.raise_for_status()
166165
result = response.json()
@@ -180,17 +179,19 @@ def evaluate_feedback_then_suggest_next_solution(
180179
execution_output: str,
181180
auth_headers: dict = {},
182181
timeout: Union[int, Tuple[int, int]] = (10, 3650),
182+
api_keys: Optional[Dict[str, str]] = None,
183183
) -> Dict[str, Any]:
184184
"""Evaluate the feedback and suggest the next solution."""
185185
try:
186186
# Truncate the execution output before sending to backend
187187
truncated_output = truncate_output(execution_output)
188188

189+
request_json = {"execution_output": truncated_output, "metadata": {}}
190+
if api_keys:
191+
request_json["api_keys"] = api_keys
192+
189193
response = requests.post(
190-
f"{__base_url__}/runs/{run_id}/suggest",
191-
json={"execution_output": truncated_output, "metadata": {}},
192-
headers=auth_headers,
193-
timeout=timeout,
194+
f"{__base_url__}/runs/{run_id}/suggest", json=request_json, headers=auth_headers, timeout=timeout
194195
)
195196
response.raise_for_status()
196197
result = response.json()

0 commit comments

Comments
 (0)