Skip to content

Commit ae04109

Browse files
WEIFENG2333, claude, happy-otter
committed
feat: add inference progress callback support
Server stores progress in _progress dict during inference, exposes get_progress RPC method. Client supports progress_callback parameter on Model.infer() with threaded polling. Blocking RPC methods now run in executor so progress queries can be served concurrently.

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
1 parent 3d251a4 commit ae04109

File tree

4 files changed

+128
-7
lines changed

4 files changed

+128
-7
lines changed

src/funasr_server/client.py

Lines changed: 39 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,7 @@ def infer(
8181
merge_vad: bool = None,
8282
merge_length_s: float = None,
8383
output_timestamp: bool = None,
84+
progress_callback=None,
8485
**kwargs,
8586
) -> list:
8687
"""Run inference on this model.
@@ -96,12 +97,13 @@ def infer(
9697
merge_vad: Merge short VAD segments.
9798
merge_length_s: Max merge length in seconds.
9899
output_timestamp: Include timestamps in output.
100+
progress_callback: Optional callable(current, total) for progress.
99101
**kwargs: Additional generate() parameters.
100102
101103
Returns:
102104
List of result dicts.
103105
"""
104-
return self._client.infer(
106+
infer_kwargs = dict(
105107
audio=audio, text=text, audio_bytes=audio_bytes,
106108
name=self.name,
107109
language=language, use_itn=use_itn, batch_size=batch_size,
@@ -110,6 +112,34 @@ def infer(
110112
**kwargs,
111113
)
112114

115+
if progress_callback is None:
116+
return self._client.infer(**infer_kwargs)
117+
118+
import threading
119+
120+
result_box = [None]
121+
error_box = [None]
122+
123+
def _do_infer():
124+
try:
125+
result_box[0] = self._client.infer(**infer_kwargs)
126+
except Exception as e:
127+
error_box[0] = e
128+
129+
t = threading.Thread(target=_do_infer)
130+
t.start()
131+
while t.is_alive():
132+
try:
133+
p = self.get_progress()
134+
progress_callback(p["current"], p["total"])
135+
except Exception:
136+
pass
137+
t.join(timeout=0.5)
138+
139+
if error_box[0] is not None:
140+
raise error_box[0]
141+
return result_box[0]
142+
113143
def transcribe(
114144
self,
115145
audio: str = None,
@@ -119,6 +149,10 @@ def transcribe(
119149
"""Transcribe audio — convenience alias for infer()."""
120150
return self.infer(audio=audio, audio_bytes=audio_bytes, **kwargs)
121151

152+
def get_progress(self) -> dict:
153+
"""Get inference progress. Returns {"current": int, "total": int}."""
154+
return self._client.get_progress(name=self.name)
155+
122156
def unload(self) -> dict:
123157
"""Unload this model from the server."""
124158
return self._client.unload_model(name=self.name)
@@ -367,6 +401,10 @@ def unload_model(self, name: str) -> dict:
367401
"""Unload a model by name. Prefer ``model.unload()`` instead."""
368402
return self._rpc_call("unload_model", {"name": name})
369403

404+
def get_progress(self, name: str) -> dict:
405+
"""Get inference progress. Returns {"current": int, "total": int}."""
406+
return self._rpc_call("get_progress", {"name": name}, timeout=5)
407+
370408
def infer(
371409
self,
372410
name: str,

src/funasr_server/runtime_template/server.py

Lines changed: 33 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
import argparse
12+
import asyncio
1213
import base64
1314
import io
1415
import json
@@ -32,6 +33,7 @@
3233

3334
_models: dict = {} # name -> AutoModel instance
3435
_model_kwargs: dict = {} # name -> kwargs used to create it
36+
_progress: dict = {} # name -> {"current": int, "total": int}
3537
_exec_globals: dict = {"__builtins__": __builtins__} # shared exec namespace
3638

3739

@@ -184,9 +186,16 @@ def rpc_infer(params: dict) -> dict:
184186
if input_data is None:
185187
raise ValueError("'input' or 'input_base64' is required")
186188

189+
def _on_progress(current, total):
190+
_progress[name] = {"current": current, "total": total}
191+
192+
_progress[name] = {"current": 0, "total": 0}
187193
try:
188-
result = model.generate(input=input_data, **generate_kwargs)
194+
result = model.generate(
195+
input=input_data, progress_callback=_on_progress, **generate_kwargs
196+
)
189197
finally:
198+
_progress.pop(name, None)
190199
if tmp_file and os.path.exists(tmp_file):
191200
os.unlink(tmp_file)
192201

@@ -284,6 +293,19 @@ def rpc_list_models(params: dict) -> dict:
284293
}
285294

286295

296+
def rpc_get_progress(params: dict) -> dict:
297+
"""Get inference progress for a model.
298+
299+
Params:
300+
name (str): Model cache key (default: "default")
301+
302+
Returns:
303+
{"current": int, "total": int} — 0/0 if not running.
304+
"""
305+
name = params.get("name", "default")
306+
return _progress.get(name, {"current": 0, "total": 0})
307+
308+
287309
def rpc_shutdown(params: dict) -> dict:
288310
"""Gracefully shut down the server."""
289311
logger.info("Shutdown requested")
@@ -300,9 +322,14 @@ def rpc_shutdown(params: dict) -> dict:
300322
"execute": rpc_execute,
301323
"download_model": rpc_download_model,
302324
"list_models": rpc_list_models,
325+
"get_progress": rpc_get_progress,
303326
"shutdown": rpc_shutdown,
304327
}
305328

329+
# Methods that block for a long time and must run in a thread pool
330+
# so other requests (like get_progress) can be handled concurrently.
331+
_BLOCKING_METHODS = {"infer", "transcribe", "load_model", "download_model"}
332+
306333

307334
# ---------------------------------------------------------------------------
308335
# HTTP handler
@@ -333,7 +360,11 @@ async def handle_rpc(request: Request):
333360
return JSONResponse(_error(req_id, -32601, f"Method not found: {method}"))
334361

335362
try:
336-
result = handler(params)
363+
if method in _BLOCKING_METHODS:
364+
loop = asyncio.get_event_loop()
365+
result = await loop.run_in_executor(None, handler, params)
366+
else:
367+
result = handler(params)
337368
return JSONResponse(_ok(req_id, result))
338369
except Exception as e:
339370
logger.exception(f"Error in method '{method}'")

tests/test_client.py

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,8 @@ def do_POST(self):
3838
result = {"models": {}}
3939
elif method == "execute":
4040
result = {"output": "ok", "return_value": None}
41+
elif method == "get_progress":
42+
result = {"current": 3, "total": 10}
4143
elif method == "download_model":
4244
result = {"model": params.get("model"), "path": "/tmp/model", "hub": params.get("hub", "ms")}
4345
elif method == "error_test":
@@ -417,3 +419,45 @@ def test_rpc_rejects_no_result(client):
417419
def test_rpc_rejects_malformed_error(client):
418420
with pytest.raises(ConnectionError, match="Malformed JSON-RPC error"):
419421
client._rpc_call("malformed_error", {})
422+
423+
424+
# ------------------------------------------------------------------
425+
# Progress
426+
# ------------------------------------------------------------------
427+
428+
def test_get_progress(client):
429+
result = client.get_progress(name="test")
430+
assert result["current"] == 3
431+
assert result["total"] == 10
432+
433+
434+
def test_model_get_progress(client):
435+
with patch("funasr_server.client.get_hub", return_value="ms"):
436+
model = client.load_model("test-model", name="my_model")
437+
result = model.get_progress()
438+
assert result["current"] == 3
439+
assert result["total"] == 10
440+
441+
442+
def test_model_infer_with_progress_callback(client):
443+
with patch("funasr_server.client.get_hub", return_value="ms"):
444+
model = client.load_model("test-model")
445+
446+
calls = []
447+
448+
def on_progress(current, total):
449+
calls.append((current, total))
450+
451+
result = model.infer(audio="test.wav", progress_callback=on_progress)
452+
assert result[0]["text"] == "hello world"
453+
# progress_callback should have been called at least once
454+
assert len(calls) >= 1
455+
assert calls[0] == (3, 10)
456+
457+
458+
def test_model_infer_without_progress_callback(client):
459+
with patch("funasr_server.client.get_hub", return_value="ms"):
460+
model = client.load_model("test-model")
461+
# Without progress_callback — should work exactly as before
462+
result = model.infer(audio="test.wav")
463+
assert result[0]["text"] == "hello world"

tests/test_server.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -213,7 +213,9 @@ def test_rpc_infer():
213213

214214
result = server.rpc_infer({"input": "test.wav"})
215215
assert result["results"] == [{"key": "test", "text": "hello"}]
216-
mock_model.generate.assert_called_once_with(input="test.wav")
216+
call_kwargs = mock_model.generate.call_args[1]
217+
assert call_kwargs["input"] == "test.wav"
218+
assert "progress_callback" in call_kwargs
217219

218220

219221
def test_rpc_infer_by_name():
@@ -233,7 +235,11 @@ def test_rpc_infer_passes_extra_kwargs():
233235
server._models["default"] = mock_model
234236

235237
server.rpc_infer({"input": "test.wav", "language": "zh", "use_itn": True})
236-
mock_model.generate.assert_called_once_with(input="test.wav", language="zh", use_itn=True)
238+
call_kwargs = mock_model.generate.call_args[1]
239+
assert call_kwargs["input"] == "test.wav"
240+
assert call_kwargs["language"] == "zh"
241+
assert call_kwargs["use_itn"] is True
242+
assert "progress_callback" in call_kwargs
237243

238244

239245
def test_rpc_infer_model_not_loaded():
@@ -276,7 +282,9 @@ def test_rpc_transcribe_maps_params():
276282

277283
result = server.rpc_transcribe({"audio": "test.wav"})
278284
assert result["results"] == [{"key": "test", "text": "hello"}]
279-
mock_model.generate.assert_called_once_with(input="test.wav")
285+
call_kwargs = mock_model.generate.call_args[1]
286+
assert call_kwargs["input"] == "test.wav"
287+
assert "progress_callback" in call_kwargs
280288

281289

282290
def test_rpc_transcribe_maps_audio_base64():
@@ -411,5 +419,5 @@ def test_rpc_list_models_multiple():
411419

412420
def test_methods_dispatch_table():
413421
expected = {"health", "load_model", "unload_model", "infer", "transcribe",
414-
"execute", "download_model", "list_models", "shutdown"}
422+
"execute", "download_model", "list_models", "get_progress", "shutdown"}
415423
assert expected == set(server._METHODS.keys())

0 commit comments

Comments
 (0)