Skip to content

Commit 9376555

Browse files
WEIFENG2333claudehappy-otter
committed
refactor: load_model returns Model handle object
- Add Model class with infer(), transcribe(), unload(), __call__()
- load_model() returns Model instead of dict, auto-generates name
- Remove FunASR.transcribe() — use model.transcribe() instead
- Simplify FunASR.infer() to internal API used by Model
- Clean up tests to use Model-based API throughout

Usage:
    model = asr.load_model("SenseVoiceSmall", vad_model="fsmn-vad")
    result = model.infer(audio="test.wav")
    result = model(audio="test.wav")  # shorthand
    model.unload()

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
1 parent 0cad59c commit 9376555

File tree

4 files changed

+289
-526
lines changed

4 files changed

+289
-526
lines changed

src/funasr_server/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from funasr_server.client import FunASR, ServerError
1+
from funasr_server.client import FunASR, Model, ServerError
22
from funasr_server.mirror import detect_region, get_hub
33
from funasr_server.models import resolve_model_id, list_available_models
44

55
__all__ = [
66
"FunASR",
7+
"Model",
78
"ServerError",
89
"detect_region",
910
"get_hub",

src/funasr_server/client.py

Lines changed: 122 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,20 @@
66
Usage:
77
from funasr_server import FunASR
88
9-
asr = FunASR()
10-
asr.ensure_installed()
11-
asr.start()
9+
with FunASR() as asr:
10+
# load_model() returns a Model handle
11+
model = asr.load_model("SenseVoiceSmall", vad_model="fsmn-vad")
12+
result = model.infer(audio="audio.wav")
13+
result = model(audio="audio.wav") # shorthand
1214
13-
# ASR — model name auto-resolved to correct hub
14-
asr.load_model(model="SenseVoiceSmall")
15-
result = asr.infer(audio="audio.wav", language="zh", use_itn=True)
15+
# Multiple models
16+
vad = asr.load_model("fsmn-vad")
17+
punc = asr.load_model("ct-punc")
1618
17-
# VAD (standalone)
18-
asr.load_model(model="fsmn-vad", name="vad")
19-
result = asr.infer(audio="audio.wav", name="vad")
19+
segments = vad(audio="audio.wav")
20+
text = punc(text="你好世界今天天气真好")
2021
21-
# ASR + VAD pipeline
22-
asr.load_model(model="SenseVoiceSmall", vad_model="fsmn-vad", name="asr_vad")
23-
result = asr.infer(audio="audio.wav", name="asr_vad")
24-
25-
# Punctuation model
26-
asr.load_model(model="ct-punc", name="punc")
27-
result = asr.infer(text="你好世界今天天气真好", name="punc")
28-
29-
asr.stop()
22+
model.unload()
3023
"""
3124

3225
import base64
@@ -57,6 +50,93 @@ def __init__(self, code: int, message: str, data: str = None):
5750
super().__init__(message)
5851

5952

53+
class Model:
    """Handle to a model that has been loaded on the server.

    Returned by ``FunASR.load_model()``. Wraps the server-side cache key so
    callers can run inference and manage the model's lifecycle without
    passing a name string around.

    Usage::

        model = asr.load_model("SenseVoiceSmall", vad_model="fsmn-vad")
        result = model.infer(audio="test.wav")
        # or just call it directly
        result = model(audio="test.wav")
        model.unload()
    """

    def __init__(self, client: "FunASR", name: str):
        # The owning client performs all RPC traffic; ``name`` is the
        # server-side cache key identifying this loaded model.
        self._client = client
        self.name = name

    def infer(
        self,
        audio: str = None,
        text: str = None,
        audio_bytes: bytes = None,
        language: str = None,
        use_itn: bool = None,
        batch_size: int = None,
        hotword: str = None,
        merge_vad: bool = None,
        merge_length_s: float = None,
        output_timestamp: bool = None,
        **kwargs,
    ) -> list:
        """Run inference on this model.

        Args:
            audio: Path to audio file.
            text: Text string (for punctuation models).
            audio_bytes: Raw audio bytes.
            language: Language code (e.g. "zh", "en", "ja").
            use_itn: Enable inverse text normalization.
            batch_size: Inference batch size.
            hotword: Hotword string.
            merge_vad: Merge short VAD segments.
            merge_length_s: Max merge length in seconds.
            output_timestamp: Include timestamps in output.
            **kwargs: Additional generate() parameters.

        Returns:
            List of result dicts.
        """
        # Delegate to the client; unset (None) options are filtered out
        # by FunASR.infer() before the RPC is issued.
        return self._client.infer(
            name=self.name,
            audio=audio,
            text=text,
            audio_bytes=audio_bytes,
            language=language,
            use_itn=use_itn,
            batch_size=batch_size,
            hotword=hotword,
            merge_vad=merge_vad,
            merge_length_s=merge_length_s,
            output_timestamp=output_timestamp,
            **kwargs,
        )

    def transcribe(
        self,
        audio: str = None,
        audio_bytes: bytes = None,
        **kwargs,
    ) -> list:
        """Transcribe audio — convenience alias for infer()."""
        return self(audio=audio, audio_bytes=audio_bytes, **kwargs)

    def unload(self) -> dict:
        """Unload this model from the server and free its resources."""
        return self._client.unload_model(name=self.name)

    def __call__(
        self,
        audio: str = None,
        text: str = None,
        audio_bytes: bytes = None,
        **kwargs,
    ) -> list:
        """Shorthand for infer()."""
        return self.infer(audio, text, audio_bytes, **kwargs)

    def __repr__(self):
        return "Model(name={!r})".format(self.name)
138+
139+
60140
class FunASR:
61141
"""Client for the FunASR inference server.
62142
@@ -81,6 +161,7 @@ def __init__(
81161
self.installer = Installer(str(self.runtime_dir))
82162
self._process: Optional[subprocess.Popen] = None
83163
self._rpc_id = 0
164+
self._model_counter = 0
84165

85166
# ------------------------------------------------------------------
86167
# Lifecycle
@@ -199,7 +280,7 @@ def health(self) -> dict:
199280
def load_model(
200281
self,
201282
model: str,
202-
name: str = "default",
283+
name: str = None,
203284
vad_model: str = None,
204285
punc_model: str = None,
205286
spk_model: str = None,
@@ -210,11 +291,15 @@ def load_model(
210291
fp16: bool = None,
211292
disable_update: bool = None,
212293
**kwargs,
213-
) -> dict:
294+
) -> "Model":
214295
"""Load any FunASR model via AutoModel.
215296
216-
Works with ALL model types: ASR, VAD, punctuation, speaker,
217-
emotion, alignment, keyword spotting, etc.
297+
Returns a ``Model`` handle that can be used for inference directly::
298+
299+
model = asr.load_model("SenseVoiceSmall", vad_model="fsmn-vad")
300+
result = model.infer(audio="test.wav")
301+
result = model(audio="test.wav") # shorthand
302+
model.unload()
218303
219304
Model names are automatically resolved to the correct hub-specific
220305
ID based on the detected network region. For example,
@@ -224,7 +309,7 @@ def load_model(
224309
225310
Args:
226311
model: Model name or ID (e.g. "SenseVoiceSmall", "fsmn-vad").
227-
name: Cache key for this model instance.
312+
name: Server-side cache key. Auto-generated if None.
228313
vad_model: VAD model for ASR pipeline composition.
229314
punc_model: Punctuation model for ASR pipeline.
230315
spk_model: Speaker model for ASR pipeline.
@@ -246,8 +331,12 @@ def load_model(
246331
- spk_kwargs (dict): Extra speaker model parameters.
247332
248333
Returns:
249-
{"name": str, "status": "loaded" | "already_loaded"}
334+
A ``Model`` handle for inference and lifecycle management.
250335
"""
336+
if name is None:
337+
self._model_counter += 1
338+
name = f"model_{self._model_counter}"
339+
251340
if hub is None:
252341
hub = get_hub()
253342

@@ -271,71 +360,29 @@ def load_model(
271360
params["name"] = name # always include name
272361
params.update(kwargs)
273362

274-
return self._rpc_call("load_model", params, timeout=600)
363+
self._rpc_call("load_model", params, timeout=600)
364+
return Model(self, name)
275365

276-
def unload_model(self, name: str = "default") -> dict:
277-
"""Unload a model and free memory."""
366+
def unload_model(self, name: str) -> dict:
367+
"""Unload a model by name. Prefer ``model.unload()`` instead."""
278368
return self._rpc_call("unload_model", {"name": name})
279369

280370
def infer(
281371
self,
372+
name: str,
373+
*,
282374
audio: str = None,
283375
text: str = None,
284376
audio_bytes: bytes = None,
285-
name: str = "default",
286-
language: str = None,
287-
use_itn: bool = None,
288-
batch_size: int = None,
289-
hotword: str = None,
290-
merge_vad: bool = None,
291-
merge_length_s: float = None,
292-
output_timestamp: bool = None,
293377
**kwargs,
294378
) -> list:
295-
"""Universal inference — works with ANY loaded FunASR model.
296-
297-
Calls model.generate(input=..., **kwargs) on the server side.
298-
Provide exactly one of: audio, text, or audio_bytes.
379+
"""Run inference on a loaded model by name.
299380
300-
Args:
301-
audio: Path to audio file (for ASR/VAD/speaker models).
302-
text: Text string (for punctuation/text models).
303-
audio_bytes: Raw audio bytes (WAV/MP3/etc.).
304-
name: Model cache key (default: "default").
305-
language: Language code (e.g. "zh", "en", "ja").
306-
use_itn: Enable inverse text normalization.
307-
batch_size: Inference batch size.
308-
hotword: Hotword file path or hotword string.
309-
merge_vad: Merge short VAD segments.
310-
merge_length_s: Max merge length in seconds.
311-
output_timestamp: Include timestamps in output.
312-
**kwargs: Additional generate() parameters. Common options:
313-
- itn (bool): Inverse text normalization (some models).
314-
- text_norm (str): Text normalization mode.
315-
- batch_size_s (int): Batch size in seconds for VAD inference.
316-
- data_type (str): Input data type hint.
317-
318-
Returns:
319-
List of result dicts. Structure depends on model type:
320-
- ASR: [{"key": ..., "text": ...}]
321-
- VAD: [{"key": ..., "value": [[start_ms, end_ms], ...]}]
322-
- Punctuation: [{"key": ..., "text": ...}]
381+
Prefer ``model.infer()`` or ``model()`` instead.
323382
"""
324-
# Build generate kwargs, filtering out None values
325-
named_generate = {
326-
"language": language,
327-
"use_itn": use_itn,
328-
"batch_size": batch_size,
329-
"hotword": hotword,
330-
"merge_vad": merge_vad,
331-
"merge_length_s": merge_length_s,
332-
"output_timestamp": output_timestamp,
333-
}
334383
params = {"name": name}
335-
params.update({k: v for k, v in named_generate.items() if v is not None})
336-
params.update(kwargs)
384+
params.update({k: v for k, v in kwargs.items() if v is not None})
337385

338-
# Set input
339386
if audio_bytes is not None:
340387
params["input_base64"] = base64.b64encode(audio_bytes).decode()
341388
elif audio is not None:
@@ -354,32 +401,6 @@ def infer(
354401
result = self._rpc_call("infer", params, timeout=600)
355402
return result.get("results", [])
356403

357-
def transcribe(
358-
self,
359-
audio: str = None,
360-
audio_bytes: bytes = None,
361-
name: str = "default",
362-
**kwargs,
363-
) -> list:
364-
"""Transcribe audio — convenience alias for infer().
365-
366-
Args:
367-
audio: Path to audio file.
368-
audio_bytes: Raw audio bytes (WAV/MP3/etc.)
369-
name: Model cache key.
370-
**kwargs: Additional generate() parameters
371-
(language, use_itn, hotword, etc.)
372-
373-
Returns:
374-
List of result dicts with at least "text" key.
375-
"""
376-
return self.infer(
377-
audio=audio,
378-
audio_bytes=audio_bytes,
379-
name=name,
380-
**kwargs,
381-
)
382-
383404
def execute(self, code: str, return_var: str = "result", **kwargs) -> dict:
384405
"""Execute arbitrary Python code in the server environment.
385406

0 commit comments

Comments
 (0)