Skip to content

Commit d37c54f

Browse files
WEIFENG2333claudehappy-otter
committed
Improve unit test coverage: 53 → 107 tests
test_mirror.py (7 → 10): auto-detect paths, all config fields, OSError test_installer.py (9 → 22): _ensure_uv platform logic, _uv_sync errors, env passing, idempotent create, cached uv path, macOS path test_client.py (17 → 35): lifecycle (start/stop/ensure_installed), error attributes, download_model, load_model kwargs, RPC id increment test_server.py (20 → 40): download_model hub variants, infer base64, transcribe base64, execute globals persist, custom return_var, serialize edge cases, CUDA cache, model replace, kwargs passthrough Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
1 parent f749543 commit d37c54f

File tree

4 files changed

+641
-37
lines changed

4 files changed

+641
-37
lines changed

tests/test_client.py

Lines changed: 184 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests for the FunASR client."""
22

33
import json
4+
import subprocess
45
import tempfile
56
from http.server import HTTPServer, BaseHTTPRequestHandler
67
from pathlib import Path
@@ -26,7 +27,11 @@ def do_POST(self):
2627
if method == "health":
2728
result = {"status": "ok", "loaded_models": [], "device": "cpu", "cuda_available": False}
2829
elif method == "load_model":
29-
result = {"name": params.get("name", "default"), "status": "loaded"}
30+
result = {"name": params.get("name", "default"), "status": "loaded",
31+
"hub": params.get("hub"), "device": params.get("device"),
32+
"vad_model": params.get("vad_model")}
33+
result = {k: v for k, v in result.items() if v is not None}
34+
result["status"] = "loaded"
3035
elif method == "unload_model":
3136
result = {"name": params.get("name", "default"), "status": "unloaded"}
3237
elif method == "infer":
@@ -37,10 +42,16 @@ def do_POST(self):
3742
result = {"models": {}}
3843
elif method == "execute":
3944
result = {"output": "ok", "return_value": None}
45+
elif method == "download_model":
46+
result = {"model": params.get("model"), "path": "/tmp/model", "hub": params.get("hub", "ms")}
4047
elif method == "error_test":
4148
resp = {"jsonrpc": "2.0", "id": req_id, "error": {"code": -32000, "message": "test error"}}
4249
self._send_json(resp)
4350
return
51+
elif method == "error_with_data":
52+
resp = {"jsonrpc": "2.0", "id": req_id, "error": {"code": -32000, "message": "test error", "data": "traceback info"}}
53+
self._send_json(resp)
54+
return
4455
else:
4556
resp = {"jsonrpc": "2.0", "id": req_id, "error": {"code": -32601, "message": f"Method not found: {method}"}}
4657
self._send_json(resp)
@@ -80,10 +91,118 @@ def client(mock_server):
8091
return c
8192

8293

94+
# ------------------------------------------------------------------
95+
# Lifecycle
96+
# ------------------------------------------------------------------
97+
98+
def test_init_defaults():
99+
with tempfile.TemporaryDirectory() as tmpdir:
100+
test_dir = Path(tmpdir) / "test_funasr"
101+
c = FunASR(runtime_dir=str(test_dir))
102+
assert c.runtime_dir == test_dir.resolve()
103+
assert c.port == 0
104+
assert c.host == "127.0.0.1"
105+
assert c._process is None
106+
107+
108+
def test_is_running_no_process():
109+
c = FunASR()
110+
assert c.is_running() is False
111+
112+
113+
def test_is_running_with_live_process():
114+
c = FunASR()
115+
mock_proc = MagicMock()
116+
mock_proc.poll.return_value = None # still running
117+
c._process = mock_proc
118+
assert c.is_running() is True
119+
120+
121+
def test_is_running_with_dead_process():
122+
c = FunASR()
123+
mock_proc = MagicMock()
124+
mock_proc.poll.return_value = 1 # exited
125+
c._process = mock_proc
126+
assert c.is_running() is False
127+
128+
129+
def test_stop_no_process():
130+
c = FunASR()
131+
c.stop() # should not raise
132+
133+
134+
def test_stop_graceful():
135+
"""stop() sends shutdown RPC then waits for process exit."""
136+
c = FunASR()
137+
mock_proc = MagicMock()
138+
c._process = mock_proc
139+
140+
with patch.object(c, "_rpc_call"):
141+
c.stop()
142+
143+
mock_proc.wait.assert_called_once()
144+
assert c._process is None
145+
146+
147+
def test_stop_force_kill():
148+
"""stop() force-kills when process doesn't exit gracefully."""
149+
c = FunASR()
150+
mock_proc = MagicMock()
151+
mock_proc.wait.side_effect = [subprocess.TimeoutExpired("cmd", 5), None]
152+
c._process = mock_proc
153+
154+
with patch.object(c, "_rpc_call"), \
155+
patch("platform.system", return_value="Linux"):
156+
c.stop()
157+
158+
mock_proc.send_signal.assert_called_once()
159+
assert c._process is None
160+
161+
162+
def test_ensure_installed_already_installed():
163+
c = FunASR()
164+
with patch.object(c.installer, "is_installed", return_value=True):
165+
result = c.ensure_installed()
166+
assert result is True
167+
168+
169+
def test_ensure_installed_fresh():
170+
c = FunASR()
171+
with patch.object(c.installer, "is_installed", return_value=False), \
172+
patch.object(c.installer, "install"):
173+
result = c.ensure_installed()
174+
assert result is False
175+
176+
177+
def test_start_no_uv():
178+
"""start() raises if uv is not found."""
179+
c = FunASR()
180+
with patch.object(c.installer, "get_uv_path", return_value=None):
181+
with pytest.raises(RuntimeError, match="uv not found"):
182+
c.start()
183+
184+
185+
def test_start_already_running():
186+
"""start() returns early if server is already running."""
187+
c = FunASR()
188+
mock_proc = MagicMock()
189+
mock_proc.poll.return_value = None # alive
190+
c._process = mock_proc
191+
c.port = 1234
192+
193+
result = c.start()
194+
assert result == 1234
195+
196+
197+
# ------------------------------------------------------------------
198+
# RPC API via mock server
199+
# ------------------------------------------------------------------
200+
83201
def test_health(client):
84202
result = client.health()
85203
assert result["status"] == "ok"
86204
assert "loaded_models" in result
205+
assert "cuda_available" in result
87206

88207

89208
def test_load_model(client):
@@ -97,11 +216,31 @@ def test_load_model_with_name(client):
97216
assert result["name"] == "my_model"
98217

99218

219+
def test_load_model_with_kwargs(client):
220+
"""Extra kwargs (hub, device, etc.) are passed through."""
221+
result = client.load_model(model="test-model", hub="hf", device="cpu")
222+
assert result["status"] == "loaded"
223+
assert result["hub"] == "hf"
224+
assert result["device"] == "cpu"
225+
226+
227+
def test_load_model_with_vad_model(client):
228+
"""vad_model parameter is passed through."""
229+
result = client.load_model(model="test-model", vad_model="fsmn-vad")
230+
assert result["vad_model"] == "fsmn-vad"
231+
232+
100233
def test_unload_model(client):
101234
result = client.unload_model()
102235
assert result["status"] == "unloaded"
103236

104237

238+
def test_unload_model_by_name(client):
239+
result = client.unload_model(name="my_model")
240+
assert result["name"] == "my_model"
241+
assert result["status"] == "unloaded"
242+
243+
105244
def test_infer(client):
106245
result = client.infer(input="test.wav")
107246
assert len(result) == 1
@@ -113,6 +252,12 @@ def test_infer_with_bytes(client):
113252
assert len(result) == 1
114253

115254

255+
def test_infer_with_name(client):
256+
"""infer() passes model name."""
257+
result = client.infer(input="test.wav", name="vad")
258+
assert len(result) == 1
259+
260+
116261
def test_infer_no_input(client):
117262
with pytest.raises(ValueError, match="Either 'input' or 'input_bytes'"):
118263
client.infer()
@@ -144,31 +289,52 @@ def test_execute(client):
144289
assert result["output"] == "ok"
145290

146291

292+
def test_download_model(client):
293+
result = client.download_model(model="iic/SenseVoiceSmall")
294+
assert result["model"] == "iic/SenseVoiceSmall"
295+
assert result["path"] == "/tmp/model"
296+
297+
298+
def test_download_model_with_hub(client):
299+
result = client.download_model(model="test-model", hub="hf")
300+
assert result["hub"] == "hf"
301+
302+
303+
# ------------------------------------------------------------------
304+
# Error handling
305+
# ------------------------------------------------------------------
306+
147307
def test_server_error(client):
148308
with pytest.raises(ServerError, match="test error"):
149309
client._rpc_call("error_test", {})
150310

151311

152-
def test_connection_error():
153-
c = FunASR(port=1, host="127.0.0.1")
154-
with pytest.raises(ConnectionError):
155-
c._rpc_call("health", {}, timeout=1)
312+
def test_server_error_with_data(client):
313+
"""ServerError includes data field from server response."""
314+
with pytest.raises(ServerError) as exc_info:
315+
client._rpc_call("error_with_data", {})
316+
assert exc_info.value.code == -32000
317+
assert exc_info.value.data == "traceback info"
156318

157319

158-
def test_context_manager_init():
159-
with tempfile.TemporaryDirectory() as tmpdir:
160-
test_dir = Path(tmpdir) / "test_funasr"
161-
c = FunASR(runtime_dir=str(test_dir))
162-
assert c.runtime_dir == test_dir.resolve()
163-
assert c.port == 0
164-
assert c.host == "127.0.0.1"
320+
def test_server_error_attributes():
321+
"""ServerError stores code and data."""
322+
err = ServerError(-32000, "test message", "extra data")
323+
assert err.code == -32000
324+
assert err.data == "extra data"
325+
assert str(err) == "test message"
165326

166327

167-
def test_is_running_no_process():
168-
c = FunASR()
169-
assert c.is_running() is False
328+
def test_connection_error():
329+
c = FunASR(port=1, host="127.0.0.1")
330+
with pytest.raises(ConnectionError):
331+
c._rpc_call("health", {}, timeout=1)
170332

171333

172-
def test_stop_no_process():
173-
c = FunASR()
174-
c.stop() # should not raise
334+
def test_rpc_id_increments(client):
335+
"""Each RPC call increments the request ID."""
336+
initial = client._rpc_id
337+
client.health()
338+
assert client._rpc_id == initial + 1
339+
client.health()
340+
assert client._rpc_id == initial + 2

0 commit comments

Comments
 (0)