Skip to content

Commit cc5e2a5

Browse files
WEIFENG2333claudehappy-otter
committed
feat: support audio+text input for forced alignment and auto-sync server.py
- Support simultaneous audio + text input in infer() for models like fa-zh that require both (forced alignment). Server constructs tuple input with data_type=("sound", "text") automatically. - Auto-sync server.py from template on start() by comparing file contents, so package upgrades take effect without re-install. Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
1 parent 904e25d commit cc5e2a5

File tree

3 files changed

+86
-1
lines changed

3 files changed

+86
-1
lines changed

src/funasr_server/client.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import logging
2828
import os
2929
import platform
30+
import shutil
3031
import signal
3132
import subprocess
3233
import time
@@ -39,6 +40,8 @@
3940
from funasr_server.mirror import get_hub
4041
from funasr_server.models import resolve_model_id
4142

43+
_TEMPLATE_DIR = Path(__file__).parent / "runtime_template"
44+
4245
logger = logging.getLogger(__name__)
4346

4447

@@ -231,6 +234,8 @@ def start(self, timeout: float = 60) -> int:
231234
logger.info(f"Server already running (pid={self._process.pid})")
232235
return self.port
233236

237+
self._sync_server_py()
238+
234239
uv_path = self.installer.get_uv_path()
235240
if not uv_path:
236241
raise RuntimeError("uv not found. Call ensure_installed() first.")
@@ -303,6 +308,22 @@ def is_running(self) -> bool:
303308
return False
304309
return self._process.poll() is None
305310

311+
def _sync_server_py(self):
312+
"""Sync server.py from template to runtime dir if outdated.
313+
314+
Called automatically before each start() so that upgrading the
315+
funasr-server package immediately picks up server-side changes
316+
without requiring a full re-install.
317+
"""
318+
template = _TEMPLATE_DIR / "server.py"
319+
runtime = self.runtime_dir / "server.py"
320+
if not template.exists() or not runtime.exists():
321+
return
322+
if runtime.read_bytes() == template.read_bytes():
323+
return
324+
shutil.copy2(template, runtime)
325+
logger.info("Synced server.py from template (updated)")
326+
306327
# ------------------------------------------------------------------
307328
# High-level API
308329
# ------------------------------------------------------------------
@@ -423,11 +444,15 @@ def infer(
423444

424445
if audio_bytes is not None:
425446
params["input_base64"] = base64.b64encode(audio_bytes).decode()
447+
if text is not None:
448+
params["text"] = text
426449
elif audio is not None:
427450
if os.path.exists(audio):
428451
params["input"] = str(Path(audio).resolve())
429452
else:
430453
params["input"] = audio
454+
if text is not None:
455+
params["text"] = text
431456
elif text is not None:
432457
params["input"] = text
433458
else:

src/funasr_server/runtime_template/server.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,10 @@ def rpc_infer(params: dict) -> dict:
170170
input_base64 = params.get("input_base64")
171171
tmp_file = None
172172

173+
text = params.get("text")
174+
173175
# Build generate kwargs: everything except control params
174-
_control_keys = {"name", "input", "input_base64", "audio_format"}
176+
_control_keys = {"name", "input", "input_base64", "audio_format", "text"}
175177
generate_kwargs = {k: v for k, v in params.items() if k not in _control_keys}
176178

177179
if input_base64:
@@ -186,6 +188,12 @@ def rpc_infer(params: dict) -> dict:
186188
if input_data is None:
187189
raise ValueError("'input' or 'input_base64' is required")
188190

191+
# For models that need both audio and text (e.g. fa-zh forced alignment),
192+
# pass input as a tuple with data_type hint.
193+
if text is not None and input_data is not None:
194+
input_data = (input_data, text)
195+
generate_kwargs["data_type"] = ("sound", "text")
196+
189197
def _on_progress(current, total):
190198
_progress[name] = {"current": current, "total": total}
191199

tests/test_client.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,3 +461,55 @@ def test_model_infer_without_progress_callback(client):
461461
# Without progress_callback — should work exactly as before
462462
result = model.infer(audio="test.wav")
463463
assert result[0]["text"] == "hello world"
464+
465+
466+
# ------------------------------------------------------------------
467+
# server.py sync tests
468+
# ------------------------------------------------------------------
469+
470+
def test_sync_server_py_copies_when_different(tmp_path):
471+
"""_sync_server_py copies template when runtime file differs."""
472+
runtime = tmp_path / "runtime"
473+
runtime.mkdir()
474+
server_py = runtime / "server.py"
475+
server_py.write_text("old content")
476+
477+
asr = FunASR(runtime_dir=str(runtime))
478+
479+
template = Path(__file__).parent.parent / "src" / "funasr_server" / "runtime_template" / "server.py"
480+
template_content = template.read_bytes()
481+
482+
asr._sync_server_py()
483+
484+
assert server_py.read_bytes() == template_content
485+
486+
487+
def test_sync_server_py_skips_when_same(tmp_path):
488+
"""_sync_server_py does nothing when files are identical."""
489+
runtime = tmp_path / "runtime"
490+
runtime.mkdir()
491+
server_py = runtime / "server.py"
492+
493+
template = Path(__file__).parent.parent / "src" / "funasr_server" / "runtime_template" / "server.py"
494+
template_content = template.read_bytes()
495+
server_py.write_bytes(template_content)
496+
497+
mtime_before = server_py.stat().st_mtime
498+
499+
asr = FunASR(runtime_dir=str(runtime))
500+
asr._sync_server_py()
501+
502+
# File should not have been touched (same mtime)
503+
assert server_py.stat().st_mtime == mtime_before
504+
505+
506+
def test_sync_server_py_skips_when_no_runtime(tmp_path):
507+
"""_sync_server_py does nothing when runtime server.py doesn't exist yet."""
508+
runtime = tmp_path / "runtime"
509+
runtime.mkdir()
510+
# No server.py in runtime dir
511+
512+
asr = FunASR(runtime_dir=str(runtime))
513+
asr._sync_server_py() # should not raise
514+
515+
assert not (runtime / "server.py").exists()

0 commit comments

Comments
 (0)