Skip to content

Commit 540c11b

Browse files
committed
FastTextLangIdClassifier processor implementation
Signed-off-by: Sasha Meister <ameister@nvidia.com>
1 parent 93cfc46 commit 540c11b

File tree

5 files changed

+176
-1
lines changed

5 files changed

+176
-1
lines changed

docs/src/sdp/api.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ used in the downstream processing for additional enhancement or filtering.
208208
.. autodata:: sdp.processors.AudioLid
209209
:annotation:
210210

211+
.. autodata:: sdp.processors.FastTextLangIdClassifier
212+
:annotation:
213+
211214
Text-only processors
212215
####################
213216

requirements/main.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,5 @@ datasets>=2.14.0,<3.0.0
2828
# for FasterWhisperInference processor is required:
2929
# pip install pytorch-lightning nvidia-cublas-cu12 nvidia-cudnn-cu12==9.* faster_whisper
3030
# export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'`
31-
# for vLLMInference processor is required: pip install "optree>=0.13.0" vllm
31+
# for vLLMInference processor is required: pip install "optree>=0.13.0" vllm
32+
# for FastTextLangIdClassifier processor is required: pip install fasttext

sdp/processors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@
148148
from sdp.processors.inference.asr.utils.whisper_hallucinations import DetectWhisperHallucinationFeatures
149149
from sdp.processors.inference.asr.utils.rttm import GetRttmSegments, SplitAudioFile
150150
from sdp.processors.inference.nlp.nemo.pc_inference import PCInference
151+
from sdp.processors.inference.nlp.fasttext.fasttext import FastTextLangIdClassifier
151152
from sdp.processors.inference.llm.vllm.vllm import vLLMInference
152153
from sdp.processors.inference.llm.utils.qwen_cleaning import CleanQwenGeneration
153154

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import requests
17+
import tempfile
18+
import wget
19+
20+
from sdp.logging import logger
21+
from sdp.processors.base_processor import BaseParallelProcessor, DataEntry
22+
23+
24+
class FastTextLangIdClassifier(BaseParallelProcessor):
25+
"""
26+
This processor supports language identification using pretrained FastText models.
27+
It classifies text and adds the predicted label and probability to the dataset entry.
28+
If needed, it downloads the model, loads it into memory, and performs prediction on the
29+
specified input text field.
30+
31+
Args:
32+
model_name_or_path (str): Path to a FastText model file or the name of a supported remote model ('lid.176.bin' or 'lid.176.ftz').
33+
text_field (str): The name of the field in the dataset entry that contains the input text for classification.
34+
output_field (str): The name of the field to store the predicted label. Defaults to "label".
35+
cache_dir (str, optional): Directory to store the downloaded model file. If not provided, a temporary directory is used.
36+
**kwargs: Additional keyword arguments passed to `BaseParallelProcessor`.
37+
38+
Returns:
39+
A manifest where each entry contains the original data fields plus:
40+
- `<output_field>`: The predicted label (e.g., language code for `lid.176.bin`).
41+
- `<output_field>_prob`: The probability of the prediction.
42+
43+
.. note::
44+
Make sure to install `fasttext` before using this processor:
45+
pip install fasttext
46+
47+
"""
48+
49+
SUPPROTED_MODELS_URLS = {
50+
'lid.176.bin' : 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin',
51+
'lid.176.ftz' : 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz'
52+
}
53+
54+
def __init__(
55+
self,
56+
model_name_or_path: str,
57+
text_field: str,
58+
output_field: str = "label",
59+
cache_dir: str = None,
60+
**kwargs
61+
):
62+
super().__init__(**kwargs)
63+
self.model_name_or_path = model_name_or_path
64+
self.text_field = text_field
65+
self.output_field = output_field
66+
self.cache_dir = cache_dir
67+
self._model = None
68+
69+
def _download_model(self):
70+
"""Downloads the FastText model from a predefined URL and stores it in the cache directory."""
71+
model_url = self.SUPPROTED_MODELS_URLS[self.model_name_or_path]
72+
logger.info(f'Downloading {self.model_name_or_path}..')
73+
response = requests.get(model_url)
74+
75+
if response.status_code != 200:
76+
raise requests.exceptions.RequestException(
77+
f"Failed to download model file. Status code: {response.status_code}"
78+
)
79+
80+
if self.cache_dir is None:
81+
self.cache_dir = tempfile.mkdtemp()
82+
os.makedirs(self.cache_dir, exist_ok=True)
83+
84+
self.model_name_or_path = wget.download(model_url, out=self.cache_dir)
85+
logger.info(f'Model `{self.model_name_or_path}` has been downloaded to {self.cache_dir}.')
86+
87+
def prepare(self):
88+
"""
89+
Prepares the model for classification:
90+
- Checks if the model file exists locally.
91+
- Downloads the model if only the name is given and it's known.
92+
- Raises ValueError if the path or model name is invalid.
93+
"""
94+
import fasttext
95+
96+
if not os.path.exists(self.model_name_or_path):
97+
if self.cache_dir and os.path.exists(os.path.join(self.cache_dir, self.model_name_or_path)):
98+
self.model_name_or_path = os.path.join(self.cache_dir, self.model_name_or_path)
99+
elif self.model_name_or_path in self.SUPPROTED_MODELS_URLS:
100+
self._download_model()
101+
else:
102+
raise ValueError(f'Current model is not supported or filepath is invalid: {self.model_name_or_path}.')
103+
104+
self._model = fasttext.load_model(self.model_name_or_path)
105+
106+
def process_dataset_entry(self, data_entry: dict):
107+
"""Applies the classifier to a single dataset entry."""
108+
text = data_entry[self.text_field].strip().replace("\n", " ")
109+
label, prob = self._model.predict(text)
110+
data_entry[self.output_field] = label[0].replace('__label__', '')
111+
data_entry[f"{self.output_field}_prob"] = prob[0]
112+
113+
return [DataEntry(data=data_entry)]

tests/test_fasttext_inference.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from sdp.processors.inference.nlp.fasttext.fasttext import FastTextLangIdClassifier
18+
19+
20+
@pytest.fixture(scope="module")
21+
def classifier():
22+
processor = FastTextLangIdClassifier(
23+
model_name_or_path="lid.176.ftz",
24+
text_field="text",
25+
output_field="lang",
26+
num_workers=1,
27+
batch_size=1,
28+
)
29+
processor.prepare()
30+
return processor
31+
32+
33+
@pytest.mark.parametrize("text,expected_lang", [
34+
("Hello, how are you?", "en"),
35+
("Bonjour tout le monde", "fr"),
36+
("Привет, как дела?", "ru"),
37+
("Hola, ¿cómo estás?", "es"),
38+
])
39+
def test_language_identification(classifier, text, expected_lang):
40+
input_entry = {"text": text}
41+
result = classifier.process_dataset_entry(input_entry)
42+
43+
assert isinstance(result, list)
44+
assert len(result) == 1
45+
46+
output = result[0].data
47+
assert "lang" in output
48+
assert "lang_prob" in output
49+
50+
predicted_lang = output["lang"]
51+
prob = output["lang_prob"]
52+
53+
assert isinstance(predicted_lang, str)
54+
assert 0 <= prob <= 1.0
55+
56+
#Exact matching may depend on the model, so we compare based on presence in the top predictions.
57+
assert predicted_lang == expected_lang, f"Expected: {expected_lang}, got: {predicted_lang}"

0 commit comments

Comments
 (0)