57 changes: 56 additions & 1 deletion funasr/auto/auto_model.py
@@ -30,6 +30,16 @@
from funasr.utils import export_utils
from funasr.utils import misc


def _resolve_ncpu(config, fallback=4):
"""Return a positive integer representing CPU threads from config."""
value = config.get("ncpu", fallback)
try:
value = int(value)
except (TypeError, ValueError):
value = fallback
return max(value, 1)

try:
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
from funasr.models.campplus.cluster_backend import ClusterBackend
@@ -132,6 +142,7 @@ def __init__(self, **kwargs):
vad_kwargs["model"] = vad_model
vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
vad_kwargs["device"] = kwargs["device"]
vad_kwargs.setdefault("ncpu", kwargs.get("ncpu", 4))
vad_model, vad_kwargs = self.build_model(**vad_kwargs)

# if punc_model is not None, build punc model else None
@@ -142,6 +153,7 @@ def __init__(self, **kwargs):
punc_kwargs["model"] = punc_model
punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
punc_kwargs["device"] = kwargs["device"]
punc_kwargs.setdefault("ncpu", kwargs.get("ncpu", 4))
punc_model, punc_kwargs = self.build_model(**punc_kwargs)

# if spk_model is not None, build spk model else None
@@ -155,6 +167,7 @@ def __init__(self, **kwargs):
spk_kwargs["model"] = spk_model
spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
spk_kwargs["device"] = kwargs["device"]
spk_kwargs.setdefault("ncpu", kwargs.get("ncpu", 4))
spk_model, spk_kwargs = self.build_model(**spk_kwargs)
self.cb_model = ClusterBackend(**cb_kwargs).to(kwargs["device"])
spk_mode = kwargs.get("spk_mode", "punc_segment")
@@ -171,6 +184,7 @@ def __init__(self, **kwargs):
self.spk_model = spk_model
self.spk_kwargs = spk_kwargs
self.model_path = kwargs.get("model_path")
self._store_base_configs()

@staticmethod
def build_model(**kwargs):
@@ -190,7 +204,10 @@ def build_model(**kwargs):
kwargs["batch_size"] = 1
kwargs["device"] = device

torch.set_num_threads(kwargs.get("ncpu", 4))
Contributor review comment (medium): Setting the number of threads every time build_model is called might be excessive. Consider setting it only once at the beginning of the program, or only when the ncpu value actually changes, to avoid unnecessary overhead. Also consider logging when the number of threads is actually changed, for debugging purposes.

ncpu = _resolve_ncpu(kwargs, 4)
kwargs["ncpu"] = ncpu
if torch.get_num_threads() != ncpu:
torch.set_num_threads(ncpu)
Contributor review comment (medium) on lines +209 to +210: This check is good for preventing unnecessary calls to torch.set_num_threads(). However, it would be beneficial to add a log message indicating when the number of threads is being changed, which can help with debugging and monitoring. Consider something like logging.info(f"Setting number of threads to {ncpu}").
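A minimal sketch of the guard with the suggested logging folded in (a reviewer-style illustration, not code from this PR; the helper name _set_num_threads_logged is made up):

import logging

import torch

def _set_num_threads_logged(ncpu: int) -> None:
    """Set torch's intra-op thread count only when it differs, logging the change."""
    current = torch.get_num_threads()
    if current != ncpu:
        logging.info("Setting number of threads from %d to %d", current, ncpu)
        torch.set_num_threads(ncpu)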


# build tokenizer
tokenizer = kwargs.get("tokenizer", None)
@@ -302,6 +319,7 @@ def __call__(self, *args, **cfg):
return res

def generate(self, input, input_len=None, progress_callback=None, **cfg):
self._reset_runtime_configs()
if self.vad_model is None:
return self.inference(
input, input_len=input_len, progress_callback=progress_callback, **cfg
@@ -322,6 +340,8 @@ def inference(
progress_callback=None,
**cfg,
):
if kwargs is None:
self._reset_runtime_configs()
Contributor review comment (high) on lines +343 to +344: The guard means the reset only runs when the caller passes no kwargs, since kwargs is normalized to self.kwargs on the following line. If the user provides a kwargs argument, the runtime configs are not reset, which might lead to unexpected behavior. Consider always calling self._reset_runtime_configs() at the beginning of the function to ensure consistent behavior; removing the guard also simplifies the code.

Suggested change:
-        if kwargs is None:
-            self._reset_runtime_configs()
+        self._reset_runtime_configs()
         kwargs = self.kwargs if kwargs is None else kwargs
kwargs = self.kwargs if kwargs is None else kwargs
if "cache" in kwargs:
kwargs.pop("cache")
@@ -397,6 +417,7 @@ def inference(
return asr_result_list

def inference_with_vad(self, input, input_len=None, **cfg):
self._reset_runtime_configs()
kwargs = self.kwargs
# step.1: compute the vad model
deep_update(self.vad_kwargs, cfg)
@@ -691,3 +712,37 @@ def export(self, input=None, **cfg):
export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)

return export_dir

def _store_base_configs(self):
"""Snapshot base kwargs for all submodules to allow reset before inference."""
baseline = {}
for name in dir(self):
if not name.endswith("kwargs"):
continue
value = getattr(self, name, None)
if isinstance(value, dict):
baseline[name] = copy.deepcopy(value)
# include primary kwargs explicitly
baseline["kwargs"] = copy.deepcopy(self.kwargs)
self._base_kwargs_map = baseline

def _reset_runtime_configs(self):
"""Ensure runtime kwargs reset to baseline defaults before inference."""
base_map = getattr(self, "_base_kwargs_map", None)
if not base_map:
return

for name, base in base_map.items():
restored = copy.deepcopy(base)
setattr(self, name, restored)

ncpu = _resolve_ncpu(self.kwargs, 4)
self.kwargs["ncpu"] = ncpu
for name, value in base_map.items():
if name == "kwargs":
continue
config = getattr(self, name, None)
if isinstance(config, dict):
config.setdefault("ncpu", ncpu)
Contributor review comment (high) on lines +741 to +746: The loop iterates through base_map.items(), but only name is used in the loop body; value is ignored. If the intention is simply to use the keys to access attributes of self, iterate over base_map.keys() (or base_map directly) instead. Also, the condition if name == "kwargs": continue exempts the primary kwargs from the setdefault call, which looks inconsistent with the handling of the other config dicts; since self.kwargs["ncpu"] is assigned unconditionally just above the loop, the exemption is probably intentional, but consider documenting why the primary kwargs is skipped.
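A sketch of the loop rewritten along those lines (illustrative only, not the PR's code):

# self.kwargs["ncpu"] was assigned unconditionally above, so it is skipped here.
for name in base_map:
    if name == "kwargs":
        continue
    config = getattr(self, name, None)
    if isinstance(config, dict):
        config.setdefault("ncpu", ncpu)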

if torch.get_num_threads() != ncpu:
torch.set_num_threads(ncpu)
Contributor review comment (medium) on lines +747 to +748: Consider adding a log message here to indicate when the number of threads is being reset, which can help with debugging and monitoring. For example: logging.info(f"Resetting number of threads to {ncpu}").
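For context, a minimal sketch of how the new ncpu plumbing would be exercised end to end; the model names and input path are illustrative, and the comments describe the behavior added by this PR:

from funasr import AutoModel

# ncpu propagates into the vad/punc/spk sub-model kwargs via setdefault in
# __init__, and build_model clamps it to a positive int via _resolve_ncpu
# before calling torch.set_num_threads.
model = AutoModel(
    model="paraformer-zh",  # illustrative model names
    vad_model="fsmn-vad",
    punc_model="ct-punc",
    ncpu=2,
)

# generate() first calls _reset_runtime_configs(), restoring the baseline
# kwargs snapshot taken by _store_base_configs() at the end of __init__.
result = model.generate(input="example.wav")  # illustrative input path
print(result)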