Work around FunASR kwargs state leaks #2698
```diff
@@ -30,6 +30,16 @@
 from funasr.utils import export_utils
 from funasr.utils import misc
 
+def _resolve_ncpu(config, fallback=4):
+    """Return a positive integer representing CPU threads from config."""
+    value = config.get("ncpu", fallback)
+    try:
+        value = int(value)
+    except (TypeError, ValueError):
+        value = fallback
+    return max(value, 1)
+
+
 try:
     from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
     from funasr.models.campplus.cluster_backend import ClusterBackend
```
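The helper normalizes whatever arrives under `ncpu` (strings from config files, `None`, or non-positive values) into a usable thread count. A quick illustration, with the helper copied from the diff and illustrative assertions:

```python
def _resolve_ncpu(config, fallback=4):
    """Return a positive integer representing CPU threads from config."""
    value = config.get("ncpu", fallback)
    try:
        value = int(value)
    except (TypeError, ValueError):
        value = fallback
    return max(value, 1)

assert _resolve_ncpu({"ncpu": 8}) == 8     # plain integer passes through
assert _resolve_ncpu({"ncpu": "8"}) == 8   # strings from config files are coerced
assert _resolve_ncpu({"ncpu": None}) == 4  # None falls back to the default
assert _resolve_ncpu({"ncpu": -2}) == 1    # non-positive values clamp to 1
assert _resolve_ncpu({}) == 4              # missing key uses the fallback
```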
```diff
@@ -132,6 +142,7 @@ def __init__(self, **kwargs):
             vad_kwargs["model"] = vad_model
             vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
             vad_kwargs["device"] = kwargs["device"]
+            vad_kwargs.setdefault("ncpu", kwargs.get("ncpu", 4))
             vad_model, vad_kwargs = self.build_model(**vad_kwargs)
 
         # if punc_model is not None, build punc model else None
```
```diff
@@ -142,6 +153,7 @@ def __init__(self, **kwargs):
             punc_kwargs["model"] = punc_model
             punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
             punc_kwargs["device"] = kwargs["device"]
+            punc_kwargs.setdefault("ncpu", kwargs.get("ncpu", 4))
             punc_model, punc_kwargs = self.build_model(**punc_kwargs)
 
         # if spk_model is not None, build spk model else None
```
```diff
@@ -155,6 +167,7 @@ def __init__(self, **kwargs):
             spk_kwargs["model"] = spk_model
             spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
             spk_kwargs["device"] = kwargs["device"]
+            spk_kwargs.setdefault("ncpu", kwargs.get("ncpu", 4))
             spk_model, spk_kwargs = self.build_model(**spk_kwargs)
             self.cb_model = ClusterBackend(**cb_kwargs).to(kwargs["device"])
             spk_mode = kwargs.get("spk_mode", "punc_segment")
```
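Using `setdefault` rather than plain assignment means a caller who already pinned `ncpu` for a submodule keeps that value; the parent `kwargs` only fills the gap. A small demonstration (the dict names are illustrative):

```python
vad_kwargs = {"ncpu": 2}
vad_kwargs.setdefault("ncpu", 4)
assert vad_kwargs["ncpu"] == 2   # explicit per-submodule value is preserved

punc_kwargs = {}
punc_kwargs.setdefault("ncpu", 4)
assert punc_kwargs["ncpu"] == 4  # absent key picks up the parent default
```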
```diff
@@ -171,6 +184,7 @@ def __init__(self, **kwargs):
         self.spk_model = spk_model
         self.spk_kwargs = spk_kwargs
         self.model_path = kwargs.get("model_path")
+        self._store_base_configs()
 
     @staticmethod
     def build_model(**kwargs):
```
```diff
@@ -190,7 +204,10 @@ def build_model(**kwargs):
         kwargs["batch_size"] = 1
         kwargs["device"] = device
 
-        torch.set_num_threads(kwargs.get("ncpu", 4))
+        ncpu = _resolve_ncpu(kwargs, 4)
+        kwargs["ncpu"] = ncpu
+        if torch.get_num_threads() != ncpu:
+            torch.set_num_threads(ncpu)
 
         # build tokenizer
         tokenizer = kwargs.get("tokenizer", None)
```
**Contributor** commented on lines +209 to +210:

> This check is good to prevent unnecessary calls to `torch.set_num_threads`. Consider adding a log message when the thread count is actually changed.
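A minimal sketch of such logging, assuming Python's standard `logging` module; the helper name `_apply_ncpu` is hypothetical and not part of the PR:

```python
import logging

import torch

logger = logging.getLogger(__name__)


def _apply_ncpu(kwargs):
    """Resolve ncpu from kwargs and sync torch's thread count, logging changes."""
    ncpu = _resolve_ncpu(kwargs, 4)  # _resolve_ncpu as defined in the diff above
    kwargs["ncpu"] = ncpu
    current = torch.get_num_threads()
    if current != ncpu:
        logger.info("torch num_threads: %d -> %d", current, ncpu)
        torch.set_num_threads(ncpu)
    return ncpu
```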
```diff
@@ -302,6 +319,7 @@ def __call__(self, *args, **cfg):
         return res
 
     def generate(self, input, input_len=None, progress_callback=None, **cfg):
+        self._reset_runtime_configs()
         if self.vad_model is None:
             return self.inference(
                 input, input_len=input_len, progress_callback=progress_callback, **cfg
```
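For context, the leak this works around: per-call options merged into `self.kwargs` used to survive into later calls. With the reset at the top of `generate`, each call starts from the constructor-time snapshot. A hedged usage sketch (the model id and per-call option are illustrative FunASR usage, not part of this diff):

```python
from funasr import AutoModel

model = AutoModel(model="paraformer-zh", ncpu=4)

# Per-call option intended for this request only...
res_a = model.generate(input="a.wav", batch_size_s=300)

# ...which no longer leaks into the next call: _reset_runtime_configs()
# restores the kwargs snapshot taken at construction time.
res_b = model.generate(input="b.wav")
```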
```diff
@@ -322,6 +340,8 @@ def inference(
         progress_callback=None,
         **cfg,
     ):
+        if kwargs is None:
+            self._reset_runtime_configs()
         kwargs = self.kwargs if kwargs is None else kwargs
         if "cache" in kwargs:
             kwargs.pop("cache")
```
**Contributor** commented on lines +343 to +344:

> It seems like the condition `kwargs is None` is evaluated twice in a row here. Also, the call to `self._reset_runtime_configs()` could move into the same branch as the fallback assignment on the next line.
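The review included a suggested change; one plausible shape for it, merging the duplicated checks, is sketched below as a fragment of `inference`'s body (an assumption about the reviewer's intent, not the verbatim suggestion):

```python
# Evaluate "kwargs is None" once: reset the runtime configs and fall back
# to the instance-level kwargs in the same branch.
if kwargs is None:
    self._reset_runtime_configs()
    kwargs = self.kwargs
if "cache" in kwargs:
    kwargs.pop("cache")
```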
```diff
@@ -397,6 +417,7 @@ def inference(
         return asr_result_list
 
     def inference_with_vad(self, input, input_len=None, **cfg):
+        self._reset_runtime_configs()
         kwargs = self.kwargs
         # step.1: compute the vad model
         deep_update(self.vad_kwargs, cfg)
```
```diff
@@ -691,3 +712,37 @@ def export(self, input=None, **cfg):
         export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
 
         return export_dir
+
+    def _store_base_configs(self):
+        """Snapshot base kwargs for all submodules to allow reset before inference."""
+        baseline = {}
+        for name in dir(self):
+            if not name.endswith("kwargs"):
+                continue
+            value = getattr(self, name, None)
+            if isinstance(value, dict):
+                baseline[name] = copy.deepcopy(value)
+        # include primary kwargs explicitly
+        baseline["kwargs"] = copy.deepcopy(self.kwargs)
+        self._base_kwargs_map = baseline
+
+    def _reset_runtime_configs(self):
+        """Ensure runtime kwargs reset to baseline defaults before inference."""
+        base_map = getattr(self, "_base_kwargs_map", None)
+        if not base_map:
+            return
+
+        for name, base in base_map.items():
+            restored = copy.deepcopy(base)
+            setattr(self, name, restored)
+
+        ncpu = _resolve_ncpu(self.kwargs, 4)
+        self.kwargs["ncpu"] = ncpu
+        for name, value in base_map.items():
+            if name == "kwargs":
+                continue
+            config = getattr(self, name, None)
+            if isinstance(config, dict):
+                config.setdefault("ncpu", ncpu)
```
**Contributor** commented on lines +741 to +746:

> The loop iterates through `base_map.items()` but never uses the values, so iterating over the keys alone would be clearer. Also, the condition `isinstance(config, dict)` appears redundant: the reset loop above has just restored every entry from a deep-copied dict.
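A sketch of the simplification the review points at, as a drop-in fragment of `_reset_runtime_configs` (behavior unchanged):

```python
# The snapshot values are not needed here; only the attribute names are.
for name in base_map:
    if name == "kwargs":
        continue
    config = getattr(self, name, None)
    if isinstance(config, dict):
        config.setdefault("ncpu", ncpu)
```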
The hunk then closes by re-applying the resolved thread count:

```diff
+        if torch.get_num_threads() != ncpu:
+            torch.set_num_threads(ncpu)
```
**Contributor** commented on lines +747 to +748:

> Setting the number of threads every time `build_model` is called might be excessive. Consider setting it only once at the beginning of the program, or only when the `ncpu` value actually changes, to avoid unnecessary overhead. Also, consider logging when the number of threads is actually changed, for debugging purposes.
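A sketch of that set-once idea using a module-level cache; the helper and its name are hypothetical, not part of this PR:

```python
import logging

import torch

logger = logging.getLogger(__name__)

_last_ncpu = None  # last thread count this process applied


def _set_num_threads_once(ncpu):
    """Call torch.set_num_threads only when the requested value changes."""
    global _last_ncpu
    if ncpu == _last_ncpu:
        return
    logger.info("setting torch num_threads to %d (was %s)", ncpu, _last_ncpu)
    torch.set_num_threads(ncpu)
    _last_ncpu = ncpu
```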