
Commit 407cdb0

Expand CI-safe test coverage
1 parent 29f09cb commit 407cdb0

6 files changed: 133 additions & 10 deletions

inference/serve.py

Lines changed: 6 additions & 8 deletions
@@ -5,7 +5,7 @@
 import random
 import time
 from dataclasses import asdict, dataclass
-from typing import Any, Dict, Iterable, List
+from typing import Any, Dict, Iterable, List, Tuple
 
 import ray
 import grpc
@@ -14,7 +14,7 @@
 from envd.generated import tools_pb2, tools_pb2_grpc
 from .config import SamplerConfig, StorageConfig
 from .model import create_sampler, ensure_json_plan
-from .storage import create_storage
+from .storage import create_storage, write_rollout_records
 
 
 @dataclass(slots=True)
@@ -148,7 +148,7 @@ async def batch_rollouts(self, prompts: List[str], *, replicas: int = 3, timeout
             prompt_map[ref] = prompt
 
         deadline = time.perf_counter() + timeout_s
-        results: List[Dict[str, Any]] = []
+        pairs: List[Tuple[str, Dict[str, Any]]] = []
         pending = list(object_refs)
 
         while pending and time.perf_counter() < deadline:
@@ -160,16 +160,14 @@ async def batch_rollouts(self, prompts: List[str], *, replicas: int = 3, timeout
                 result = await ref
             except Exception as exc:  # noqa: BLE001
                 result = {"error": str(exc)}
-            record = {"prompt": prompt, "result": result, "timestamp_s": time.time()}
-            results.append(record)
+            pairs.append((prompt, result))
             prompt_map.pop(ref, None)
 
         for ref in pending:
            ray.cancel(ref, force=True)
 
-        if results:
-            await self._storage.write(results)
-        return results
+        records = await write_rollout_records(self._storage, pairs)
+        return records
 
 
 def bootstrap_ray(env_hosts: List[str], *, num_samplers: int = 2):
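With this refactor the records returned by batch_rollouts keep the same shape as before ("prompt", "result", "timestamp_s"); only their construction and persistence move into the storage helpers below. A small illustrative consumer of those records (not part of the commit; summarize is a hypothetical helper) might look like:

# Hypothetical consumer of the records returned by batch_rollouts; the field
# names mirror build_rollout_records in inference/storage.py below.
def summarize(records):
    errors = [r for r in records if "error" in r["result"]]
    return {
        "total": len(records),
        "errors": len(errors),
        "latest_ts": max((r["timestamp_s"] for r in records), default=None),
    }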

inference/storage.py

Lines changed: 22 additions & 2 deletions
@@ -4,7 +4,8 @@
 import json
 import uuid
 from pathlib import Path
-from typing import Any, Dict, Iterable
+import time
+from typing import Any, Dict, Iterable, List, Tuple, Callable
 
 from .config import StorageConfig
 
@@ -60,6 +61,25 @@ async def write(self, records: Iterable[Dict[str, Any]]) -> None:
         return None
 
 
+def build_rollout_records(entries: Iterable[Tuple[str, Dict[str, Any]]], now: Callable[[], float] | None = None) -> List[Dict[str, Any]]:
+    timestamp = now or time.time
+    return [
+        {
+            "prompt": prompt,
+            "result": result,
+            "timestamp_s": timestamp(),
+        }
+        for prompt, result in entries
+    ]
+
+
+async def write_rollout_records(storage: StorageWriter, entries: Iterable[Tuple[str, Dict[str, Any]]], now: Callable[[], float] | None = None) -> List[Dict[str, Any]]:
+    records = build_rollout_records(entries, now=now)
+    if records:
+        await storage.write(records)
+    return records
+
+
 def create_storage(cfg: StorageConfig) -> StorageWriter:
     try:
         if cfg.kind == "s3" and cfg.s3_bucket:
@@ -73,4 +93,4 @@ def create_storage(cfg: StorageConfig) -> StorageWriter:
     return NoOpWriter()
 
 
-__all__ = ["create_storage", "StorageWriter"]
+__all__ = ["create_storage", "StorageWriter", "build_rollout_records", "write_rollout_records"]
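The new helpers take an injectable now callable, which is what makes them exercisable in CI without real clocks or storage backends. A minimal sketch of driving them directly (not part of the commit; ListWriter is a hypothetical stand-in for any StorageWriter):

# Sketch only: exercise the new helpers with an in-memory writer and a fixed
# clock. ListWriter just satisfies the async write() interface.
import asyncio

from inference.storage import build_rollout_records, write_rollout_records

class ListWriter:
    def __init__(self):
        self.rows = []

    async def write(self, records):
        self.rows.extend(records)

writer = ListWriter()
entries = [("hello", {"ok": True})]
records = asyncio.run(write_rollout_records(writer, entries, now=lambda: 0.0))
assert records == build_rollout_records(entries, now=lambda: 0.0)
assert writer.rows == records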

tests/test_controller_bridge.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+import asyncio
+
+from inference.storage import build_rollout_records, write_rollout_records
+
+
+class InMemoryStorage:
+    def __init__(self):
+        self.records = None
+
+    async def write(self, records):
+        self.records = list(records)
+
+
+def test_build_rollout_records_structure():
+    records = build_rollout_records([("prompt-1", {"ok": True})], now=lambda: 123.0)
+    assert records[0]["prompt"] == "prompt-1"
+    assert records[0]["result"] == {"ok": True}
+    assert records[0]["timestamp_s"] == 123.0
+
+
+def test_write_rollout_records_persists():
+    storage = InMemoryStorage()
+    records = asyncio.run(write_rollout_records(storage, [("prompt-2", {"latency": 1.0})], now=lambda: 456.0))
+    assert storage.records == records
+    assert records[0]["timestamp_s"] == 456.0

tests/test_reward.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+from trainer.reward import compute_reward
+
+
+def test_compute_reward_balances_signals():
+    metrics = {
+        "tests_passed": 1,
+        "lint_improvement": 0.5,
+        "parallel_groups": 3,
+        "regressions": 0,
+    }
+    reward = compute_reward(metrics, latency_s=10.0)
+    assert reward > 0
+
+
+def test_compute_reward_penalizes_regressions():
+    metrics = {
+        "tests_passed": 0,
+        "lint_improvement": 0,
+        "parallel_groups": 0,
+        "regressions": 2,
+    }
+    reward = compute_reward(metrics, latency_s=5.0)
+    assert reward < 0
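These tests only pin the sign of the reward; trainer/reward.py itself is not touched by this commit, so the weights below are purely illustrative. A minimal compute_reward consistent with both assertions could look like:

# Hypothetical sketch of a reward consistent with the tests above; the real
# trainer/reward.py is not shown in this commit, so the weights are invented.
def compute_reward(metrics, *, latency_s):
    reward = (
        2.0 * metrics.get("tests_passed", 0)
        + 1.0 * metrics.get("lint_improvement", 0)
        + 0.5 * metrics.get("parallel_groups", 0)
        - 3.0 * metrics.get("regressions", 0)
    )
    # Mild latency penalty so faster rollouts score slightly higher.
    return reward - 0.01 * latency_s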

tests/test_scripts.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+import subprocess
+from pathlib import Path
+
+import pytest
+
+SCRIPTS = [
+    Path("scripts/firecracker/build_base.sh"),
+    Path("scripts/firecracker/create_template.sh"),
+    Path("scripts/firecracker/launch_envs.sh"),
+]
+
+
+@pytest.mark.parametrize("script_path", SCRIPTS)
+def test_firecracker_scripts_shellcheck(script_path):
+    full_path = Path(__file__).resolve().parents[1] / script_path
+    result = subprocess.run(["bash", "-n", str(full_path)], capture_output=True, text=True)
+    assert result.returncode == 0, result.stderr
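Despite the test name, bash -n only parses the scripts; it does not run shellcheck, which keeps the check CI-safe on runners without the linter installed. A hedged companion test (not in this commit) that runs the real linter only when it is available could sit alongside it:

# Hypothetical companion test, not part of this commit: run shellcheck when
# installed, otherwise skip. Reuses the SCRIPTS list defined above in
# tests/test_scripts.py.
import shutil
import subprocess
from pathlib import Path

import pytest

@pytest.mark.skipif(shutil.which("shellcheck") is None, reason="shellcheck not installed")
@pytest.mark.parametrize("script_path", SCRIPTS)
def test_firecracker_scripts_shellcheck_lint(script_path):
    full_path = Path(__file__).resolve().parents[1] / script_path
    result = subprocess.run(["shellcheck", str(full_path)], capture_output=True, text=True)
    assert result.returncode == 0, result.stdout + result.stderr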

tests/test_vllm_sampler.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+import asyncio
+import json
+
+import httpx
+import pytest
+
+from inference.config import SamplerConfig
+from inference.model import create_sampler
+
+
+class DummyResponse:
+    def __init__(self, payload):
+        self._payload = payload
+
+    def raise_for_status(self):
+        return None
+
+    def json(self):
+        return self._payload
+
+
+@pytest.mark.asyncio
+async def test_vllm_sampler_formats_plan(monkeypatch):
+    cfg = SamplerConfig(kind="vllm-openai", vllm_rpc_host="localhost")
+    sampler = create_sampler(cfg)
+
+    async def fake_post(self, url, json):
+        assert "Fix bug" in json["prompt"]
+        return DummyResponse({"choices": [{"text": "{\"then\": []}"}]})
+
+    async def fake_close(self):
+        return None
+
+    monkeypatch.setattr(httpx.AsyncClient, "post", fake_post, raising=False)
+    monkeypatch.setattr(httpx.AsyncClient, "aclose", fake_close, raising=False)
+
+    plan = await sampler.sample("Fix bug")
+    assert json.loads(plan)["then"] == []
+
+    await sampler.close()
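The @pytest.mark.asyncio marker assumes the pytest-asyncio plugin is installed; without it the coroutine test is collected but never awaited. Patching httpx.AsyncClient at the class level works regardless of how the sampler builds its client; if the sampler can instead accept a preconfigured client, httpx's built-in MockTransport is an alternative worth knowing (illustrative sketch, not part of the commit):

# Hypothetical alternative to class-level monkeypatching: httpx.MockTransport
# fakes responses per request. Only applicable if create_sampler (or the
# sampler class) can be handed a preconfigured httpx.AsyncClient.
import httpx

def fake_vllm(request: httpx.Request) -> httpx.Response:
    # Return a canned vLLM-style completion for any prompt.
    return httpx.Response(200, json={"choices": [{"text": "{\"then\": []}"}]})

mock_client = httpx.AsyncClient(transport=httpx.MockTransport(fake_vllm))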
