11"""Tests for the FunASR client."""
22
33import json
4+ import subprocess
45import tempfile
56from http .server import HTTPServer , BaseHTTPRequestHandler
67from pathlib import Path
@@ -26,7 +27,11 @@ def do_POST(self):
2627 if method == "health" :
2728 result = {"status" : "ok" , "loaded_models" : [], "device" : "cpu" , "cuda_available" : False }
2829 elif method == "load_model" :
29- result = {"name" : params .get ("name" , "default" ), "status" : "loaded" }
30+ result = {"name" : params .get ("name" , "default" ), "status" : "loaded" ,
31+ "hub" : params .get ("hub" ), "device" : params .get ("device" ),
32+ "vad_model" : params .get ("vad_model" )}
33+ result = {k : v for k , v in result .items () if v is not None }
34+ result ["status" ] = "loaded"
3035 elif method == "unload_model" :
3136 result = {"name" : params .get ("name" , "default" ), "status" : "unloaded" }
3237 elif method == "infer" :
@@ -37,10 +42,16 @@ def do_POST(self):
3742 result = {"models" : {}}
3843 elif method == "execute" :
3944 result = {"output" : "ok" , "return_value" : None }
45+ elif method == "download_model" :
46+ result = {"model" : params .get ("model" ), "path" : "/tmp/model" , "hub" : params .get ("hub" , "ms" )}
4047 elif method == "error_test" :
4148 resp = {"jsonrpc" : "2.0" , "id" : req_id , "error" : {"code" : - 32000 , "message" : "test error" }}
4249 self ._send_json (resp )
4350 return
51+ elif method == "error_with_data" :
52+ resp = {"jsonrpc" : "2.0" , "id" : req_id , "error" : {"code" : - 32000 , "message" : "test error" , "data" : "traceback info" }}
53+ self ._send_json (resp )
54+ return
4455 else :
4556 resp = {"jsonrpc" : "2.0" , "id" : req_id , "error" : {"code" : - 32601 , "message" : f"Method not found: { method } " }}
4657 self ._send_json (resp )
@@ -80,10 +91,118 @@ def client(mock_server):
8091 return c
8192
8293
94+ # ------------------------------------------------------------------
95+ # Lifecycle
96+ # ------------------------------------------------------------------
97+
98+ def test_init_defaults ():
99+ with tempfile .TemporaryDirectory () as tmpdir :
100+ test_dir = Path (tmpdir ) / "test_funasr"
101+ c = FunASR (runtime_dir = str (test_dir ))
102+ assert c .runtime_dir == test_dir .resolve ()
103+ assert c .port == 0
104+ assert c .host == "127.0.0.1"
105+ assert c ._process is None
106+
107+
108+ def test_is_running_no_process ():
109+ c = FunASR ()
110+ assert c .is_running () is False
111+
112+
113+ def test_is_running_with_live_process ():
114+ c = FunASR ()
115+ mock_proc = MagicMock ()
116+ mock_proc .poll .return_value = None # still running
117+ c ._process = mock_proc
118+ assert c .is_running () is True
119+
120+
121+ def test_is_running_with_dead_process ():
122+ c = FunASR ()
123+ mock_proc = MagicMock ()
124+ mock_proc .poll .return_value = 1 # exited
125+ c ._process = mock_proc
126+ assert c .is_running () is False
127+
128+
129+ def test_stop_no_process ():
130+ c = FunASR ()
131+ c .stop () # should not raise
132+
133+
134+ def test_stop_graceful ():
135+ """stop() sends shutdown RPC then waits for process exit."""
136+ c = FunASR ()
137+ mock_proc = MagicMock ()
138+ c ._process = mock_proc
139+
140+ with patch .object (c , "_rpc_call" ):
141+ c .stop ()
142+
143+ mock_proc .wait .assert_called_once ()
144+ assert c ._process is None
145+
146+
147+ def test_stop_force_kill ():
148+ """stop() force-kills when process doesn't exit gracefully."""
149+ c = FunASR ()
150+ mock_proc = MagicMock ()
151+ mock_proc .wait .side_effect = [subprocess .TimeoutExpired ("cmd" , 5 ), None ]
152+ c ._process = mock_proc
153+
154+ with patch .object (c , "_rpc_call" ), \
155+ patch ("platform.system" , return_value = "Linux" ):
156+ c .stop ()
157+
158+ mock_proc .send_signal .assert_called_once ()
159+ assert c ._process is None
160+
161+
162+ def test_ensure_installed_already_installed ():
163+ c = FunASR ()
164+ with patch .object (c .installer , "is_installed" , return_value = True ):
165+ result = c .ensure_installed ()
166+ assert result is True
167+
168+
169+ def test_ensure_installed_fresh ():
170+ c = FunASR ()
171+ with patch .object (c .installer , "is_installed" , return_value = False ), \
172+ patch .object (c .installer , "install" ):
173+ result = c .ensure_installed ()
174+ assert result is False
175+
176+
177+ def test_start_no_uv ():
178+ """start() raises if uv is not found."""
179+ c = FunASR ()
180+ with patch .object (c .installer , "get_uv_path" , return_value = None ):
181+ with pytest .raises (RuntimeError , match = "uv not found" ):
182+ c .start ()
183+
184+
185+ def test_start_already_running ():
186+ """start() returns early if server is already running."""
187+ c = FunASR ()
188+ mock_proc = MagicMock ()
189+ mock_proc .poll .return_value = None # alive
190+ c ._process = mock_proc
191+ c .port = 1234
192+
193+ result = c .start ()
194+ assert result == 1234
195+
196+
197+ # ------------------------------------------------------------------
198+ # RPC API via mock server
199+ # ------------------------------------------------------------------
200+
83201def test_health (client ):
84202 result = client .health ()
85203 assert result ["status" ] == "ok"
86204 assert "loaded_models" in result
205+ assert "cuda_available" in result
87206
88207
89208def test_load_model (client ):
@@ -97,11 +216,31 @@ def test_load_model_with_name(client):
97216 assert result ["name" ] == "my_model"
98217
99218
219+ def test_load_model_with_kwargs (client ):
220+ """Extra kwargs (hub, device, etc.) are passed through."""
221+ result = client .load_model (model = "test-model" , hub = "hf" , device = "cpu" )
222+ assert result ["status" ] == "loaded"
223+ assert result ["hub" ] == "hf"
224+ assert result ["device" ] == "cpu"
225+
226+
227+ def test_load_model_with_vad_model (client ):
228+ """vad_model parameter is passed through."""
229+ result = client .load_model (model = "test-model" , vad_model = "fsmn-vad" )
230+ assert result ["vad_model" ] == "fsmn-vad"
231+
232+
100233def test_unload_model (client ):
101234 result = client .unload_model ()
102235 assert result ["status" ] == "unloaded"
103236
104237
238+ def test_unload_model_by_name (client ):
239+ result = client .unload_model (name = "my_model" )
240+ assert result ["name" ] == "my_model"
241+ assert result ["status" ] == "unloaded"
242+
243+
105244def test_infer (client ):
106245 result = client .infer (input = "test.wav" )
107246 assert len (result ) == 1
@@ -113,6 +252,12 @@ def test_infer_with_bytes(client):
113252 assert len (result ) == 1
114253
115254
255+ def test_infer_with_name (client ):
256+ """infer() passes model name."""
257+ result = client .infer (input = "test.wav" , name = "vad" )
258+ assert len (result ) == 1
259+
260+
116261def test_infer_no_input (client ):
117262 with pytest .raises (ValueError , match = "Either 'input' or 'input_bytes'" ):
118263 client .infer ()
@@ -144,31 +289,52 @@ def test_execute(client):
144289 assert result ["output" ] == "ok"
145290
146291
292+ def test_download_model (client ):
293+ result = client .download_model (model = "iic/SenseVoiceSmall" )
294+ assert result ["model" ] == "iic/SenseVoiceSmall"
295+ assert result ["path" ] == "/tmp/model"
296+
297+
298+ def test_download_model_with_hub (client ):
299+ result = client .download_model (model = "test-model" , hub = "hf" )
300+ assert result ["hub" ] == "hf"
301+
302+
303+ # ------------------------------------------------------------------
304+ # Error handling
305+ # ------------------------------------------------------------------
306+
147307def test_server_error (client ):
148308 with pytest .raises (ServerError , match = "test error" ):
149309 client ._rpc_call ("error_test" , {})
150310
151311
152- def test_connection_error ():
153- c = FunASR (port = 1 , host = "127.0.0.1" )
154- with pytest .raises (ConnectionError ):
155- c ._rpc_call ("health" , {}, timeout = 1 )
312+ def test_server_error_with_data (client ):
313+ """ServerError includes data field from server response."""
314+ with pytest .raises (ServerError ) as exc_info :
315+ client ._rpc_call ("error_with_data" , {})
316+ assert exc_info .value .code == - 32000
317+ assert exc_info .value .data == "traceback info"
156318
157319
158- def test_context_manager_init ():
159- with tempfile .TemporaryDirectory () as tmpdir :
160- test_dir = Path (tmpdir ) / "test_funasr"
161- c = FunASR (runtime_dir = str (test_dir ))
162- assert c .runtime_dir == test_dir .resolve ()
163- assert c .port == 0
164- assert c .host == "127.0.0.1"
320+ def test_server_error_attributes ():
321+ """ServerError stores code and data."""
322+ err = ServerError (- 32000 , "test message" , "extra data" )
323+ assert err .code == - 32000
324+ assert err .data == "extra data"
325+ assert str (err ) == "test message"
165326
166327
167- def test_is_running_no_process ():
168- c = FunASR ()
169- assert c .is_running () is False
328+ def test_connection_error ():
329+ c = FunASR (port = 1 , host = "127.0.0.1" )
330+ with pytest .raises (ConnectionError ):
331+ c ._rpc_call ("health" , {}, timeout = 1 )
170332
171333
172- def test_stop_no_process ():
173- c = FunASR ()
174- c .stop () # should not raise
334+ def test_rpc_id_increments (client ):
335+ """Each RPC call increments the request ID."""
336+ initial = client ._rpc_id
337+ client .health ()
338+ assert client ._rpc_id == initial + 1
339+ client .health ()
340+ assert client ._rpc_id == initial + 2
0 commit comments