66Usage:
77 from funasr_server import FunASR
88
9- asr = FunASR()
10- asr.ensure_installed()
11- asr.start()
9+ with FunASR() as asr:
10+ # load_model() returns a Model handle
11+ model = asr.load_model("SenseVoiceSmall", vad_model="fsmn-vad")
12+ result = model.infer(audio="audio.wav")
13+ result = model(audio="audio.wav") # shorthand
1214
13- # ASR — model name auto-resolved to correct hub
14- asr.load_model(model="SenseVoiceSmall ")
15- result = asr.infer(audio="audio.wav", language="zh", use_itn=True )
15+ # Multiple models
16+ vad = asr.load_model("fsmn-vad")
17+ punc = asr.load_model("ct-punc")
1618
17- # VAD (standalone)
18- asr.load_model(model="fsmn-vad", name="vad")
19- result = asr.infer(audio="audio.wav", name="vad")
19+ segments = vad(audio="audio.wav")
20+ text = punc(text="你好世界今天天气真好")
2021
21- # ASR + VAD pipeline
22- asr.load_model(model="SenseVoiceSmall", vad_model="fsmn-vad", name="asr_vad")
23- result = asr.infer(audio="audio.wav", name="asr_vad")
24-
25- # Punctuation model
26- asr.load_model(model="ct-punc", name="punc")
27- result = asr.infer(text="你好世界今天天气真好", name="punc")
28-
29- asr.stop()
22+ model.unload()
3023"""
3124
3225import base64
@@ -57,6 +50,93 @@ def __init__(self, code: int, message: str, data: str = None):
5750 super ().__init__ (message )
5851
5952
class Model:
    """Lightweight handle for one model loaded on the FunASR server.

    Instances are created by ``FunASR.load_model()`` and remember the
    server-side cache key, so callers never have to pass a name string
    around explicitly.

    Usage::

        model = asr.load_model("SenseVoiceSmall", vad_model="fsmn-vad")
        result = model.infer(audio="test.wav")
        # or just call it directly
        result = model(audio="test.wav")
        model.unload()
    """

    def __init__(self, client: "FunASR", name: str):
        # Back-reference to the owning client plus this model's cache key.
        self._client = client
        self.name = name

    def infer(
        self,
        audio: str = None,
        text: str = None,
        audio_bytes: bytes = None,
        language: str = None,
        use_itn: bool = None,
        batch_size: int = None,
        hotword: str = None,
        merge_vad: bool = None,
        merge_length_s: float = None,
        output_timestamp: bool = None,
        **kwargs,
    ) -> list:
        """Run inference against this model.

        Args:
            audio: Path to an audio file.
            text: Text input (for punctuation models).
            audio_bytes: Raw audio bytes.
            language: Language code (e.g. "zh", "en", "ja").
            use_itn: Enable inverse text normalization.
            batch_size: Inference batch size.
            hotword: Hotword string.
            merge_vad: Merge short VAD segments.
            merge_length_s: Max merge length in seconds.
            output_timestamp: Include timestamps in output.
            **kwargs: Additional generate() parameters.

        Returns:
            List of result dicts.
        """
        # Collect every generate() option; the client drops None values
        # before sending, so forwarding them unconditionally is harmless.
        forwarded = dict(
            audio=audio,
            text=text,
            audio_bytes=audio_bytes,
            language=language,
            use_itn=use_itn,
            batch_size=batch_size,
            hotword=hotword,
            merge_vad=merge_vad,
            merge_length_s=merge_length_s,
            output_timestamp=output_timestamp,
        )
        forwarded.update(kwargs)
        return self._client.infer(name=self.name, **forwarded)

    def transcribe(
        self,
        audio: str = None,
        audio_bytes: bytes = None,
        **kwargs,
    ) -> list:
        """Transcribe audio — convenience alias for ``infer()``."""
        return self.infer(audio=audio, audio_bytes=audio_bytes, **kwargs)

    def unload(self) -> dict:
        """Unload this model on the server and free its resources."""
        return self._client.unload_model(name=self.name)

    def __call__(
        self,
        audio: str = None,
        text: str = None,
        audio_bytes: bytes = None,
        **kwargs,
    ) -> list:
        """Shorthand for ``infer()``."""
        return self.infer(audio=audio, text=text, audio_bytes=audio_bytes, **kwargs)

    def __repr__(self):
        return f"Model(name={self.name!r})"
139+
60140class FunASR :
61141 """Client for the FunASR inference server.
62142
@@ -81,6 +161,7 @@ def __init__(
81161 self .installer = Installer (str (self .runtime_dir ))
82162 self ._process : Optional [subprocess .Popen ] = None
83163 self ._rpc_id = 0
164+ self ._model_counter = 0
84165
85166 # ------------------------------------------------------------------
86167 # Lifecycle
@@ -199,7 +280,7 @@ def health(self) -> dict:
199280 def load_model (
200281 self ,
201282 model : str ,
202- name : str = "default" ,
283+ name : str = None ,
203284 vad_model : str = None ,
204285 punc_model : str = None ,
205286 spk_model : str = None ,
@@ -210,11 +291,15 @@ def load_model(
210291 fp16 : bool = None ,
211292 disable_update : bool = None ,
212293 ** kwargs ,
213- ) -> dict :
294+ ) -> "Model" :
214295 """Load any FunASR model via AutoModel.
215296
216- Works with ALL model types: ASR, VAD, punctuation, speaker,
217- emotion, alignment, keyword spotting, etc.
297+ Returns a ``Model`` handle that can be used for inference directly::
298+
299+ model = asr.load_model("SenseVoiceSmall", vad_model="fsmn-vad")
300+ result = model.infer(audio="test.wav")
301+ result = model(audio="test.wav") # shorthand
302+ model.unload()
218303
219304 Model names are automatically resolved to the correct hub-specific
220305 ID based on the detected network region. For example,
@@ -224,7 +309,7 @@ def load_model(
224309
225310 Args:
226311 model: Model name or ID (e.g. "SenseVoiceSmall", "fsmn-vad").
227- name: Cache key for this model instance .
312+ name: Server-side cache key. Auto-generated if None .
228313 vad_model: VAD model for ASR pipeline composition.
229314 punc_model: Punctuation model for ASR pipeline.
230315 spk_model: Speaker model for ASR pipeline.
@@ -246,8 +331,12 @@ def load_model(
246331 - spk_kwargs (dict): Extra speaker model parameters.
247332
248333 Returns:
249- {"name": str, "status": "loaded" | "already_loaded"}
334+ A ``Model`` handle for inference and lifecycle management.
250335 """
336+ if name is None :
337+ self ._model_counter += 1
338+ name = f"model_{ self ._model_counter } "
339+
251340 if hub is None :
252341 hub = get_hub ()
253342
@@ -271,71 +360,29 @@ def load_model(
271360 params ["name" ] = name # always include name
272361 params .update (kwargs )
273362
274- return self ._rpc_call ("load_model" , params , timeout = 600 )
363+ self ._rpc_call ("load_model" , params , timeout = 600 )
364+ return Model (self , name )
275365
276- def unload_model (self , name : str = "default" ) -> dict :
277- """Unload a model and free memory ."""
366+ def unload_model (self , name : str ) -> dict :
367+ """Unload a model by name. Prefer ``model.unload()`` instead ."""
278368 return self ._rpc_call ("unload_model" , {"name" : name })
279369
280370 def infer (
281371 self ,
372+ name : str ,
373+ * ,
282374 audio : str = None ,
283375 text : str = None ,
284376 audio_bytes : bytes = None ,
285- name : str = "default" ,
286- language : str = None ,
287- use_itn : bool = None ,
288- batch_size : int = None ,
289- hotword : str = None ,
290- merge_vad : bool = None ,
291- merge_length_s : float = None ,
292- output_timestamp : bool = None ,
293377 ** kwargs ,
294378 ) -> list :
295- """Universal inference — works with ANY loaded FunASR model.
296-
297- Calls model.generate(input=..., **kwargs) on the server side.
298- Provide exactly one of: audio, text, or audio_bytes.
379+ """Run inference on a loaded model by name.
299380
300- Args:
301- audio: Path to audio file (for ASR/VAD/speaker models).
302- text: Text string (for punctuation/text models).
303- audio_bytes: Raw audio bytes (WAV/MP3/etc.).
304- name: Model cache key (default: "default").
305- language: Language code (e.g. "zh", "en", "ja").
306- use_itn: Enable inverse text normalization.
307- batch_size: Inference batch size.
308- hotword: Hotword file path or hotword string.
309- merge_vad: Merge short VAD segments.
310- merge_length_s: Max merge length in seconds.
311- output_timestamp: Include timestamps in output.
312- **kwargs: Additional generate() parameters. Common options:
313- - itn (bool): Inverse text normalization (some models).
314- - text_norm (str): Text normalization mode.
315- - batch_size_s (int): Batch size in seconds for VAD inference.
316- - data_type (str): Input data type hint.
317-
318- Returns:
319- List of result dicts. Structure depends on model type:
320- - ASR: [{"key": ..., "text": ...}]
321- - VAD: [{"key": ..., "value": [[start_ms, end_ms], ...]}]
322- - Punctuation: [{"key": ..., "text": ...}]
381+ Prefer ``model.infer()`` or ``model()`` instead.
323382 """
324- # Build generate kwargs, filtering out None values
325- named_generate = {
326- "language" : language ,
327- "use_itn" : use_itn ,
328- "batch_size" : batch_size ,
329- "hotword" : hotword ,
330- "merge_vad" : merge_vad ,
331- "merge_length_s" : merge_length_s ,
332- "output_timestamp" : output_timestamp ,
333- }
334383 params = {"name" : name }
335- params .update ({k : v for k , v in named_generate .items () if v is not None })
336- params .update (kwargs )
384+ params .update ({k : v for k , v in kwargs .items () if v is not None })
337385
338- # Set input
339386 if audio_bytes is not None :
340387 params ["input_base64" ] = base64 .b64encode (audio_bytes ).decode ()
341388 elif audio is not None :
@@ -354,32 +401,6 @@ def infer(
354401 result = self ._rpc_call ("infer" , params , timeout = 600 )
355402 return result .get ("results" , [])
356403
357- def transcribe (
358- self ,
359- audio : str = None ,
360- audio_bytes : bytes = None ,
361- name : str = "default" ,
362- ** kwargs ,
363- ) -> list :
364- """Transcribe audio — convenience alias for infer().
365-
366- Args:
367- audio: Path to audio file.
368- audio_bytes: Raw audio bytes (WAV/MP3/etc.)
369- name: Model cache key.
370- **kwargs: Additional generate() parameters
371- (language, use_itn, hotword, etc.)
372-
373- Returns:
374- List of result dicts with at least "text" key.
375- """
376- return self .infer (
377- audio = audio ,
378- audio_bytes = audio_bytes ,
379- name = name ,
380- ** kwargs ,
381- )
382-
383404 def execute (self , code : str , return_var : str = "result" , ** kwargs ) -> dict :
384405 """Execute arbitrary Python code in the server environment.
385406
0 commit comments