Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion ovos_audio_transformer_plugin_speechbrain_langdetect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,20 @@

class SpeechBrainLangClassifier(AudioLanguageDetector):
def __init__(self, config=None):
"""
Initialize the SpeechBrain language classifier plugin instance.

Parameters:
config (dict, optional): Configuration dictionary. Recognized keys:
- "model" (str): SpeechBrain model identifier or path to use (defaults to
"speechbrain/lang-id-commonlanguage_ecapa").
- "use_cuda" (bool): If true, attempt to load the model onto CUDA; otherwise
load on the default device.

Side effects:
Sets up the plugin via the superclass and initializes `self.engine` with a
SpeechBrain EncoderClassifier instance.
"""
config = config or {}
super().__init__("ovos-audio-transformer-plugin-speechbrain-langdetect", 10, config)
model = self.config.get("model") or "speechbrain/lang-id-commonlanguage_ecapa"
Expand All @@ -25,6 +39,17 @@ def __init__(self, config=None):
self.engine = EncoderClassifier.from_hparams(source=model, savedir=f"{xdg_data_home()}/speechbrain")

def signal2probs(self, signal):
"""
Map a model input signal to language probability scores.

Runs the classifier on the provided preprocessed audio signal and returns a mapping from lowercase language codes to their predicted probabilities.

Parameters:
signal: Model-ready audio input (batch tensor or structure accepted by the classifier's classify_batch).

Returns:
dict: Mapping where each key is a lowercase BCP-47-like language code (e.g., "en-us") and each value is the language probability as a float between 0 and 1.
"""
probs, _, _, _ = self.engine.classify_batch(signal)
probs = torch.softmax(probs[0], dim=0)
labels = self.engine.hparams.label_encoder.decode_ndim(range(len(probs)))
Expand Down Expand Up @@ -83,6 +108,16 @@ def signal2probs(self, signal):

# plugin api
def detect(self, audio_data: bytes, valid_langs=None):
"""
Detects the most likely language for the given audio and returns the language code with its probability.

Parameters:
audio_data (bytes | AudioData): Raw audio bytes or an AudioData instance; raw bytes will be wrapped into an AudioData with a 16 kHz sample rate and a sample width of 2 bytes (16-bit samples).
valid_langs (Iterable[str], optional): Iterable of allowed BCP-47-like language codes to consider (e.g., "en-US", "es-ES"); if omitted the global get_valid_languages() set is used.

Returns:
tuple: If only one language is in `valid_langs`, returns (audio_data, {}) indicating no classification was performed. Otherwise returns `(lang_code, probability)` where `lang_code` is the selected language code (string) and `probability` is the normalized confidence as a float between 0 and 1.
"""
if not isinstance(audio_data, AudioData):
audio_data = AudioData(audio_data, 16000, 2)

Expand Down Expand Up @@ -111,4 +146,4 @@ def detect(self, audio_data: bytes, valid_langs=None):
s = SpeechBrainLangClassifier()
lang, prob = s.detect(audio.get_wav_data(), valid_langs=["en-us", "es-es"])
print(lang, prob)
# en-us 0.5979952496320518
# en-us 0.5979952496320518
Loading