diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index a864dadd7..900059a38 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -30,6 +30,16 @@
 from funasr.utils import export_utils
 from funasr.utils import misc
 
+
+def _resolve_ncpu(config, fallback=4):
+    """Return a positive integer representing CPU threads from config."""
+    value = config.get("ncpu", fallback)
+    try:
+        value = int(value)
+    except (TypeError, ValueError):
+        value = fallback
+    return max(value, 1)
+
 try:
     from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
     from funasr.models.campplus.cluster_backend import ClusterBackend
@@ -132,6 +142,7 @@ def __init__(self, **kwargs):
             vad_kwargs["model"] = vad_model
             vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
             vad_kwargs["device"] = kwargs["device"]
+            vad_kwargs.setdefault("ncpu", kwargs.get("ncpu", 4))
             vad_model, vad_kwargs = self.build_model(**vad_kwargs)
 
         # if punc_model is not None, build punc model else None
@@ -142,6 +153,7 @@ def __init__(self, **kwargs):
             punc_kwargs["model"] = punc_model
             punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
             punc_kwargs["device"] = kwargs["device"]
+            punc_kwargs.setdefault("ncpu", kwargs.get("ncpu", 4))
             punc_model, punc_kwargs = self.build_model(**punc_kwargs)
 
         # if spk_model is not None, build spk model else None
@@ -155,6 +167,7 @@ def __init__(self, **kwargs):
             spk_kwargs["model"] = spk_model
             spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
             spk_kwargs["device"] = kwargs["device"]
+            spk_kwargs.setdefault("ncpu", kwargs.get("ncpu", 4))
             spk_model, spk_kwargs = self.build_model(**spk_kwargs)
             self.cb_model = ClusterBackend(**cb_kwargs).to(kwargs["device"])
             spk_mode = kwargs.get("spk_mode", "punc_segment")
@@ -171,6 +184,7 @@ def __init__(self, **kwargs):
         self.spk_model = spk_model
         self.spk_kwargs = spk_kwargs
         self.model_path = kwargs.get("model_path")
+        self._store_base_configs()
 
     @staticmethod
     def build_model(**kwargs):
@@ -190,7 +204,10 @@ def build_model(**kwargs):
             kwargs["batch_size"] = 1
         kwargs["device"] = device
 
-        torch.set_num_threads(kwargs.get("ncpu", 4))
+        ncpu = _resolve_ncpu(kwargs, 4)
+        kwargs["ncpu"] = ncpu
+        if torch.get_num_threads() != ncpu:
+            torch.set_num_threads(ncpu)
 
         # build tokenizer
         tokenizer = kwargs.get("tokenizer", None)
@@ -302,6 +319,7 @@ def __call__(self, *args, **cfg):
         return res
 
     def generate(self, input, input_len=None, progress_callback=None, **cfg):
+        self._reset_runtime_configs()
         if self.vad_model is None:
             return self.inference(
                 input, input_len=input_len, progress_callback=progress_callback, **cfg
@@ -322,6 +340,8 @@ def inference(
         progress_callback=None,
         **cfg,
     ):
+        if kwargs is None:
+            self._reset_runtime_configs()
         kwargs = self.kwargs if kwargs is None else kwargs
         if "cache" in kwargs:
            kwargs.pop("cache")
@@ -397,6 +417,7 @@ def inference(
         return asr_result_list
 
     def inference_with_vad(self, input, input_len=None, **cfg):
+        self._reset_runtime_configs()
         kwargs = self.kwargs
         # step.1: compute the vad model
         deep_update(self.vad_kwargs, cfg)
@@ -691,3 +712,37 @@ def export(self, input=None, **cfg):
         export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
 
         return export_dir
+
+    def _store_base_configs(self):
+        """Snapshot base kwargs for all submodules to allow reset before inference."""
+        baseline = {}
+        for name in dir(self):
+            if not name.endswith("kwargs"):
+                continue
+            value = getattr(self, name, None)
+            if isinstance(value, dict):
+                baseline[name] = copy.deepcopy(value)
+        # include primary kwargs explicitly
+        baseline["kwargs"] = copy.deepcopy(self.kwargs)
+        self._base_kwargs_map = baseline
+
+    def _reset_runtime_configs(self):
+        """Ensure runtime kwargs reset to baseline defaults before inference."""
+        base_map = getattr(self, "_base_kwargs_map", None)
+        if not base_map:
+            return
+
+        for name, base in base_map.items():
+            restored = copy.deepcopy(base)
+            setattr(self, name, restored)
+
+        ncpu = _resolve_ncpu(self.kwargs, 4)
+        self.kwargs["ncpu"] = ncpu
+        for name, value in base_map.items():
+            if name == "kwargs":
+                continue
+            config = getattr(self, name, None)
+            if isinstance(config, dict):
+                config.setdefault("ncpu", ncpu)
+        if torch.get_num_threads() != ncpu:
+            torch.set_num_threads(ncpu)
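
Illustrative usage sketch (not part of the patch): the snippet below assumes the public funasr AutoModel API and shows the behavior this change targets. An ncpu value passed at construction is normalized by _resolve_ncpu, inherited by the vad/punc/spk sub-model kwargs via setdefault, and re-applied by _reset_runtime_configs() at the start of each generate() call. The model names and the batch_size_s option are examples taken as assumptions, not requirements of this change.

from funasr import AutoModel

# "8" is coerced to int by _resolve_ncpu inside build_model; invalid values
# fall back to 4, and the result is clamped to >= 1 before torch.set_num_threads().
model = AutoModel(
    model="paraformer-zh",   # example ASR model
    vad_model="fsmn-vad",    # vad_kwargs inherit ncpu via setdefault
    punc_model="ct-punc",    # punc_kwargs inherit ncpu the same way
    ncpu="8",
)

# generate() first restores the kwargs snapshot taken in _store_base_configs(),
# so per-call overrides passed through **cfg do not leak into later calls.
res = model.generate(input="example.wav", batch_size_s=300)
print(res)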