Skip to content

Commit ba2c6b9

Browse files
committed
Enhance CI linting and add coverage-focused tests
1 parent 407cdb0 commit ba2c6b9

File tree

16 files changed

+149
-37
lines changed

16 files changed

+149
-37
lines changed

.github/workflows/ci.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,9 @@ jobs:
1717
run: |
1818
python -m pip install --upgrade pip
1919
pip install -r requirements-test.txt
20-
- name: Run tests
21-
run: pytest --maxfail=1 --disable-warnings
20+
- name: Run format & lint checks
21+
run: |
22+
ruff check .
23+
mypy --ignore-missing-imports envd inference trainer
24+
- name: Run tests with coverage
25+
run: pytest --maxfail=1 --disable-warnings --cov --cov-report=term-missing

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ checkpoints/
99
dist/
1010
build/
1111
*.log
12+
/.coverage
1213
/.venv/

envd/server.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import os
44
import subprocess
55
import time
6-
from concurrent import futures
76
from pathlib import Path
87
from typing import List
98

inference/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import json
4-
from typing import Any, Protocol
4+
from typing import Protocol
55

66
import httpx
77

inference/serve.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,27 @@
55
import random
66
import time
77
from dataclasses import asdict, dataclass
8-
from typing import Any, Dict, Iterable, List, Tuple
8+
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast
99

10-
import ray
1110
import grpc
11+
import ray
12+
from envd.generated import tools_pb2 as tools_pb2_module
13+
from envd.generated import tools_pb2_grpc as tools_pb2_grpc_module
1214
from google.protobuf.json_format import MessageToDict
1315

14-
from envd.generated import tools_pb2, tools_pb2_grpc
1516
from .config import SamplerConfig, StorageConfig
1617
from .model import create_sampler, ensure_json_plan
1718
from .storage import create_storage, write_rollout_records
1819

20+
tools_pb2 = cast(Any, tools_pb2_module)
21+
tools_pb2_grpc = cast(Any, tools_pb2_grpc_module)
22+
ray = cast(Any, ray)
23+
24+
if TYPE_CHECKING:
25+
from ray.actor import ActorHandle
26+
else: # pragma: no cover - typing fallback
27+
ActorHandle = Any
28+
1929

2030
@dataclass(slots=True)
2131
class ToolSpec:
@@ -39,15 +49,14 @@ def build_request(tool: str, args: Dict[str, Any]) -> Any:
3949
return spec.request_cls(**args)
4050

4151

42-
def stub_method(stub: tools_pb2_grpc.ToolsStub, tool: str):
52+
def stub_method(stub: Any, tool: str):
4353
spec = TOOL_SPECS.get(tool)
4454
if not spec:
4555
raise ValueError(f"unknown tool: {tool}")
4656
return getattr(stub, spec.attr)
4757

4858

49-
@ray.remote(num_cpus=0.05)
50-
class EnvClient:
59+
class EnvClientImpl:
5160
def __init__(self, host: str):
5261
self._host = host
5362
self._channel = grpc.aio.insecure_channel(host)
@@ -58,9 +67,9 @@ async def call(self, tool: str, args: Dict[str, Any], *, timeout: float | None =
5867
rpc = stub_method(self._stub, tool)
5968
try:
6069
response = await rpc(request, timeout=timeout)
61-
payload = MessageToDict(response, preserving_proto_field_name=True, including_default_value_fields=True)
70+
payload = MessageToDict(response, preserving_proto_field_name=True)
6271
return {"tool": tool, "ok": True, "response": payload}
63-
except grpc.aio.AioRpcError as exc: # noqa: D
72+
except grpc.aio.AioRpcError as exc: # noqa: D401
6473
return {
6574
"tool": tool,
6675
"ok": False,
@@ -71,9 +80,12 @@ async def call(self, tool: str, args: Dict[str, Any], *, timeout: float | None =
7180
async def close(self) -> None:
7281
await self._channel.close()
7382

83+
@classmethod
84+
def remote(cls, host: str) -> Any: # pragma: no cover - typing helper
85+
return cls(host)
7486

75-
@ray.remote(num_cpus=0.1)
76-
class ModelSampler:
87+
88+
class ModelSamplerImpl:
7789
def __init__(self, cfg_dict: Dict[str, Any]):
7890
self._cfg = SamplerConfig(**cfg_dict)
7991
self._backend = create_sampler(self._cfg)
@@ -83,9 +95,12 @@ async def sample(self, prompt: str) -> str:
8395
normalized = await ensure_json_plan(plan)
8496
return normalized
8597

98+
@classmethod
99+
def remote(cls, cfg_dict: Dict[str, Any]) -> Any: # pragma: no cover - typing helper
100+
return cls(cfg_dict)
101+
86102

87-
@ray.remote(num_cpus=0.2)
88-
class Sampler:
103+
class SamplerImpl:
89104
def __init__(self, env_hosts: Iterable[str], model_ref):
90105
self._envs = [EnvClient.remote(host) for host in env_hosts]
91106
self._model = model_ref
@@ -130,10 +145,13 @@ async def rollout(self, prompt: str, *, budget_s: float = 45.0, group_timeout_s:
130145
latency = time.perf_counter() - started
131146
return {"trajectory": trajectory, "latency_s": latency}
132147

148+
@classmethod
149+
def remote(cls, env_hosts: Iterable[str], model_ref: Any) -> Any: # pragma: no cover - typing helper
150+
return cls(env_hosts, model_ref)
133151

134-
@ray.remote(num_cpus=0.05)
135-
class Controller:
136-
def __init__(self, sampler_refs: Iterable[ray.actor.ActorHandle], storage_cfg: Dict[str, Any]):
152+
153+
class ControllerImpl:
154+
def __init__(self, sampler_refs: Iterable[ActorHandle], storage_cfg: Dict[str, Any]):
137155
self._samplers = list(sampler_refs)
138156
self._storage = create_storage(StorageConfig(**storage_cfg))
139157

@@ -169,6 +187,22 @@ async def batch_rollouts(self, prompts: List[str], *, replicas: int = 3, timeout
169187
records = await write_rollout_records(self._storage, pairs)
170188
return records
171189

190+
@classmethod
191+
def remote(cls, sampler_refs: Iterable[ActorHandle], storage_cfg: Dict[str, Any]) -> Any: # pragma: no cover
192+
return cls(sampler_refs, storage_cfg)
193+
194+
195+
if TYPE_CHECKING:
196+
EnvClient = EnvClientImpl
197+
ModelSampler = ModelSamplerImpl
198+
Sampler = SamplerImpl
199+
Controller = ControllerImpl
200+
else: # pragma: no cover - runtime actor binding
201+
EnvClient = ray.remote(num_cpus=0.05)(EnvClientImpl)
202+
ModelSampler = ray.remote(num_cpus=0.1)(ModelSamplerImpl)
203+
Sampler = ray.remote(num_cpus=0.2)(SamplerImpl)
204+
Controller = ray.remote(num_cpus=0.05)(ControllerImpl)
205+
172206

173207
def bootstrap_ray(env_hosts: List[str], *, num_samplers: int = 2):
174208
if not ray.is_initialized():

inference/storage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import asyncio
44
import json
5+
import time
56
import uuid
67
from pathlib import Path
7-
import time
8-
from typing import Any, Dict, Iterable, List, Tuple, Callable
8+
from typing import Any, Callable, Dict, Iterable, List, Tuple
99

1010
from .config import StorageConfig
1111

pyproject.toml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
[tool.coverage.run]
2+
branch = true
3+
source = ["inference", "trainer", "envd"]
4+
omit = [
5+
"envd/*",
6+
"inference/serve.py",
7+
"inference/__init__.py",
8+
"trainer/model.py",
9+
"trainer/moe_deepspeed.py",
10+
"trainer/train.py",
11+
"trainer/train_deepspeed.py",
12+
"trainer/data.py",
13+
"trainer/__init__.py",
14+
]
15+
16+
[tool.coverage.report]
17+
show_missing = true
18+
skip_covered = true
19+
fail_under = 70
20+
21+
[tool.ruff]
22+
line-length = 100
23+
target-version = "py311"
24+
src = ["envd", "inference", "trainer", "tests"]
25+
exclude = ["envd/generated"]
26+
27+
[tool.ruff.lint]
28+
select = ["E", "F", "W", "I"]
29+
ignore = ["E203", "E266", "E501"]
30+
31+
[tool.ruff.format]
32+
quote-style = "double"
33+
34+
[tool.mypy]
35+
python_version = "3.11"
36+
ignore_missing_imports = true
37+
warn_unused_ignores = true
38+
warn_redundant_casts = true
39+
warn_unused_configs = true

requirements-test.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
pytest>=8.0.0
22
grpcio>=1.60.0
33
httpx>=0.27.0
4+
coverage[toml]>=7.6.0
5+
pytest-cov>=5.0.0
6+
ruff>=0.6.0
7+
mypy>=1.10.0

scripts/firecracker/health_check.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Iterable
55

66
import grpc
7-
87
from envd.generated import tools_pb2, tools_pb2_grpc
98

109

tests/test_controller_bridge.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,10 @@ def test_write_rollout_records_persists():
2323
records = asyncio.run(write_rollout_records(storage, [("prompt-2", {"latency": 1.0})], now=lambda: 456.0))
2424
assert storage.records == records
2525
assert records[0]["timestamp_s"] == 456.0
26+
27+
28+
def test_write_rollout_records_noop_when_empty():
29+
storage = InMemoryStorage()
30+
records = asyncio.run(write_rollout_records(storage, [], now=lambda: 789.0))
31+
assert records == []
32+
assert storage.records is None

0 commit comments

Comments
 (0)