3 changes: 3 additions & 0 deletions docs/src/sdp/api.rst
@@ -208,6 +208,9 @@ used in the downstream processing for additional enhancement or filtering.
 .. autodata:: sdp.processors.AudioLid
    :annotation:
 
+.. autodata:: sdp.processors.FastTextLangIdClassifier
+   :annotation:
+
 Text-only processors
 ####################
 
3 changes: 2 additions & 1 deletion requirements/main.txt
@@ -30,4 +30,5 @@ datasets>=2.14.0,<3.0.0
 # pip install pytorch-lightning nvidia-cublas-cu12 nvidia-cudnn-cu12==9.* faster_whisper
 # export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'`
 # for vLLMInference processor is required: pip install "optree>=0.13.0" vllm
-# for ConvertToTarredAudioDatasetConfig processor can be additionally required: pip install lhotse "nemo-toolkit[common]==2.2.1"
+# for FastTextLangIdClassifier processor is required: pip install fasttext
+# for ConvertToTarredAudioDatasetConfig processor can be additionally required: pip install lhotse "nemo-toolkit[common]==2.2.1"
1 change: 1 addition & 0 deletions sdp/processors/__init__.py
@@ -148,6 +148,7 @@
 from sdp.processors.inference.asr.utils.whisper_hallucinations import DetectWhisperHallucinationFeatures
 from sdp.processors.inference.asr.utils.rttm import GetRttmSegments, SplitAudioFile
 from sdp.processors.inference.nlp.nemo.pc_inference import PCInference
+from sdp.processors.inference.nlp.fasttext.fasttext import FastTextLangIdClassifier
 from sdp.processors.inference.llm.vllm.vllm import vLLMInference
 from sdp.processors.inference.llm.utils.qwen_cleaning import CleanQwenGeneration
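
With this re-export, the class should also be importable from the package root, like the other processors registered here. A quick sanity check, assuming an installed SDP environment (`fasttext` itself is imported lazily in `prepare()`, so it is not needed for the import alone):

```python
# Sanity check for the new package-root re-export. `fasttext` is imported
# lazily inside prepare(), so only SDP's base dependencies are needed here.
from sdp.processors import FastTextLangIdClassifier

print(sorted(FastTextLangIdClassifier.SUPPORTED_MODELS_URLS))  # ['lid.176.bin', 'lid.176.ftz']
```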

124 changes: 124 additions & 0 deletions sdp/processors/inference/nlp/fasttext/fasttext.py
@@ -0,0 +1,124 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import requests
import tempfile
import wget

from sdp.logging import logger
from sdp.processors.base_processor import BaseParallelProcessor, DataEntry


class FastTextLangIdClassifier(BaseParallelProcessor):
    """
    This processor performs language identification using pretrained FastText models.
    It classifies text and adds the predicted label and probability to the dataset entry.
    If needed, it downloads the model, loads it into memory, and performs prediction on the
    specified input text field.

    Args:
        model_name_or_path (str): Path to a FastText model file or the name of a supported remote model
            ('lid.176.bin' or 'lid.176.ftz').
        text_field (str): The name of the field in the dataset entry that contains the input text for classification.
        output_field (str): The name of the field to store the predicted label. Defaults to "label".
        top_k (int): The number of top predictions to return. Defaults to 1; set to -1 to return all labels.
        cache_dir (str, optional): Directory to store the downloaded model file. If not provided, a temporary
            directory is used.
        **kwargs: Additional keyword arguments passed to `BaseParallelProcessor`.

    Returns:
        A manifest where each entry contains the original data fields plus:
        - `<output_field>`: the predicted label (e.g., a language code for `lid.176.bin`),
        - `<output_field>_prob`: the probability of the prediction.

    Note:
        Make sure to install `fasttext` before using this processor:
        `pip install fasttext`
    """

    SUPPORTED_MODELS_URLS = {
        'lid.176.bin': 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin',
        'lid.176.ftz': 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz',
    }

    def __init__(
        self,
        model_name_or_path: str,
        text_field: str,
        output_field: str = "label",
        top_k: int = 1,
        cache_dir: str = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_name_or_path = model_name_or_path
        self.text_field = text_field
        self.output_field = output_field
        self.cache_dir = cache_dir
        self.top_k = top_k
        self._model = None

    def _download_model(self):
        """Downloads the FastText model from a predefined URL and stores it in the cache directory."""
        model_url = self.SUPPORTED_MODELS_URLS[self.model_name_or_path]
        logger.info(f'Downloading {self.model_name_or_path}...')

        # Probe the URL with a HEAD request first so the model file is not fetched twice.
        response = requests.head(model_url, allow_redirects=True)
        if response.status_code != 200:
            raise requests.exceptions.RequestException(
                f"Failed to download model file. Status code: {response.status_code}"
            )

        if self.cache_dir is None:
            self.cache_dir = tempfile.mkdtemp()
        os.makedirs(self.cache_dir, exist_ok=True)

        self.model_name_or_path = wget.download(model_url, out=self.cache_dir)
        logger.info(f'Model `{self.model_name_or_path}` has been downloaded to {self.cache_dir}.')

    def prepare(self):
        """
        Prepares the model for classification:
        - Checks if the model file exists locally.
        - Downloads the model if only the name is given and it is known.
        - Raises ValueError if the path or model name is invalid.
        """
        import fasttext

        if not os.path.exists(self.model_name_or_path):
            if self.cache_dir and os.path.exists(os.path.join(self.cache_dir, self.model_name_or_path)):
                self.model_name_or_path = os.path.join(self.cache_dir, self.model_name_or_path)
            elif self.model_name_or_path in self.SUPPORTED_MODELS_URLS:
                self._download_model()
            else:
                raise ValueError(f'Model is not supported or filepath is invalid: {self.model_name_or_path}.')

        self._model = fasttext.load_model(self.model_name_or_path)

    def process_dataset_entry(self, data_entry: dict):
        """Applies the classifier to a single dataset entry."""
        text = data_entry[self.text_field].strip().replace("\n", " ")
        # fastText returns the top-k labels and their probabilities in one call;
        # k=-1 requests all labels.
        labels, probs = self._model.predict(text, k=self.top_k)
        if self.top_k == 1:
            data_entry[self.output_field] = labels[0].replace('__label__', '')
            data_entry[f"{self.output_field}_prob"] = probs[0]
        else:
            for top_i, (_label, _prob) in enumerate(zip(labels, probs), start=1):
                data_entry[f"{self.output_field}_{top_i}"] = _label.replace('__label__', '')
                data_entry[f"{self.output_field}_prob_{top_i}"] = _prob

        return [DataEntry(data=data_entry)]
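
To make the new processor concrete, here is a minimal usage sketch mirroring the test below. The arguments follow the constructor above; `lid.176.ftz` (roughly 1 MB) is downloaded on the first `prepare()` call, and the manifest entry is illustrative:

```python
# Minimal sketch: classify the language of a single manifest entry.
# Assumes `pip install fasttext` and network access for the first model download.
from sdp.processors import FastTextLangIdClassifier

processor = FastTextLangIdClassifier(
    model_name_or_path="lid.176.ftz",  # small quantized LID model
    text_field="text",
    output_field="lang",
    cache_dir="./models",  # reuse the downloaded model across runs
    num_workers=1,
    batch_size=1,
)
processor.prepare()  # downloads the model if needed, then loads it

entry = {"text": "Bonjour tout le monde"}
result = processor.process_dataset_entry(entry)[0].data
print(result["lang"], result["lang_prob"])  # e.g. 'fr' with a high probability
```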
57 changes: 57 additions & 0 deletions tests/test_fasttext_inference.py
@@ -0,0 +1,57 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from sdp.processors.inference.nlp.fasttext.fasttext import FastTextLangIdClassifier


@pytest.fixture(scope="module")
def classifier():
processor = FastTextLangIdClassifier(
model_name_or_path="lid.176.ftz",
text_field="text",
output_field="lang",
num_workers=1,
batch_size=1,
)
processor.prepare()
return processor


@pytest.mark.parametrize("text,expected_lang", [
("Hello, how are you?", "en"),
("Bonjour tout le monde", "fr"),
("Привет, как дела?", "ru"),
("Hola, ¿cómo estás?", "es"),
])
def test_language_identification(classifier, text, expected_lang):
input_entry = {"text": text}
result = classifier.process_dataset_entry(input_entry)

assert isinstance(result, list)
assert len(result) == 1

output = result[0].data
assert "lang" in output
assert "lang_prob" in output

predicted_lang = output["lang"]
prob = output["lang_prob"]

assert isinstance(predicted_lang, str)
assert 0 <= prob <= 1.0

#Exact matching may depend on the model, so we compare based on presence in the top predictions.
assert predicted_lang == expected_lang, f"Expected: {expected_lang}, got: {predicted_lang}"
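
The tests only exercise the default `top_k=1` path. For completeness, a hedged sketch of the `top_k > 1` behaviour implied by `process_dataset_entry`, where predictions land in numbered fields:

```python
# Sketch (not part of the test suite): with top_k > 1 the processor writes
# numbered fields lang_1/lang_prob_1 ... lang_3/lang_prob_3 instead of a
# single lang/lang_prob pair. Assumes the same lid.176.ftz model as above.
from sdp.processors.inference.nlp.fasttext.fasttext import FastTextLangIdClassifier

processor = FastTextLangIdClassifier(
    model_name_or_path="lid.176.ftz",
    text_field="text",
    output_field="lang",
    top_k=3,
    num_workers=1,
    batch_size=1,
)
processor.prepare()
out = processor.process_dataset_entry({"text": "Hello, how are you?"})[0].data
print({k: v for k, v in out.items() if k.startswith("lang")})
```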