Skip to content

Commit 1731ef7

Browse files
authored
FastTextLangIdClassifier processor implementation (#149)
* FastTextLangIdClassifier processor implementation Signed-off-by: Sasha Meister <[email protected]> * Fix docs Signed-off-by: Sasha Meister <[email protected]> * Fix docs Signed-off-by: Sasha Meister <[email protected]> * Update fasttext.py Fix docs * Make returning best or all labels optional based on top_k value Signed-off-by: Sasha Meister <[email protected]> --------- Signed-off-by: Sasha Meister <[email protected]>
1 parent c9df041 commit 1731ef7

File tree

5 files changed

+187
-1
lines changed

5 files changed

+187
-1
lines changed

docs/src/sdp/api.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ used in the downstream processing for additional enhancement or filtering.
208208
.. autodata:: sdp.processors.AudioLid
209209
:annotation:
210210

211+
.. autodata:: sdp.processors.FastTextLangIdClassifier
212+
:annotation:
213+
211214
Text-only processors
212215
####################
213216

requirements/main.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,5 @@ datasets>=2.14.0,<3.0.0
3030
# pip install pytorch-lightning nvidia-cublas-cu12 nvidia-cudnn-cu12==9.* faster_whisper
3131
# export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'`
3232
# for vLLMInference processor is required: pip install "optree>=0.13.0" vllm
33-
# for ConvertToTarredAudioDatasetConfig processor can be additionally required: pip install lhotse "nemo-toolkit[common]==2.2.1"
33+
# for FastTextLangIdClassifier processor is required: pip install fasttext
34+
# for ConvertToTarredAudioDatasetConfig processor can be additionally required: pip install lhotse "nemo-toolkit[common]==2.2.1"

sdp/processors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@
148148
from sdp.processors.inference.asr.utils.whisper_hallucinations import DetectWhisperHallucinationFeatures
149149
from sdp.processors.inference.asr.utils.rttm import GetRttmSegments, SplitAudioFile
150150
from sdp.processors.inference.nlp.nemo.pc_inference import PCInference
151+
from sdp.processors.inference.nlp.fasttext.fasttext import FastTextLangIdClassifier
151152
from sdp.processors.inference.llm.vllm.vllm import vLLMInference
152153
from sdp.processors.inference.llm.utils.qwen_cleaning import CleanQwenGeneration
153154

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import requests
17+
import tempfile
18+
import wget
19+
20+
from sdp.logging import logger
21+
from sdp.processors.base_processor import BaseParallelProcessor, DataEntry
22+
23+
24+
class FastTextLangIdClassifier(BaseParallelProcessor):
25+
"""
26+
This processor supports language identification using pretrained FastText models.
27+
It classifies text and adds the predicted label and probability to the dataset entry.
28+
If needed, it downloads the model, loads it into memory, and performs prediction on the
29+
specified input text field.
30+
31+
Args:
32+
model_name_or_path (str): Path to a FastText model file or the name of a supported remote model
33+
('lid.176.bin' or 'lid.176.ftz').
34+
text_field (str): The name of the field in the dataset entry that contains the input text for classification.
35+
output_field (str): The name of the field to store the predicted label. Defaults to "label".
36+
top_k (int): The number of top predictions to return. Defaults to 1 (-1 for all).
37+
cache_dir (str, optional): Directory to store the downloaded model file. If not provided, a temporary
38+
directory is used.
39+
**kwargs: Additional keyword arguments passed to `BaseParallelProcessor`.
40+
41+
Returns:
42+
A manifest where each entry contains the original data fields plus
43+
- `<output_field>`: The predicted label (e.g., language code for `lid.176.bin`),
44+
- `<output_field>_prob`: The probability of the prediction.
45+
46+
Note:
47+
Make sure to install `fasttext` before using this processor:
48+
`pip install fasttext`
49+
"""
50+
51+
SUPPROTED_MODELS_URLS = {
52+
'lid.176.bin' : 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin',
53+
'lid.176.ftz' : 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz'
54+
}
55+
56+
def __init__(
57+
self,
58+
model_name_or_path: str,
59+
text_field: str,
60+
output_field: str = "label",
61+
top_k: int = 1,
62+
cache_dir: str = None,
63+
**kwargs
64+
):
65+
super().__init__(**kwargs)
66+
self.model_name_or_path = model_name_or_path
67+
self.text_field = text_field
68+
self.output_field = output_field
69+
self.cache_dir = cache_dir
70+
self.top_k = top_k
71+
self._model = None
72+
73+
def _download_model(self):
74+
"""Downloads the FastText model from a predefined URL and stores it in the cache directory."""
75+
model_url = self.SUPPROTED_MODELS_URLS[self.model_name_or_path]
76+
logger.info(f'Downloading {self.model_name_or_path}..')
77+
response = requests.get(model_url)
78+
79+
if response.status_code != 200:
80+
raise requests.exceptions.RequestException(
81+
f"Failed to download model file. Status code: {response.status_code}"
82+
)
83+
84+
if self.cache_dir is None:
85+
self.cache_dir = tempfile.mkdtemp()
86+
os.makedirs(self.cache_dir, exist_ok=True)
87+
88+
self.model_name_or_path = wget.download(model_url, out=self.cache_dir)
89+
logger.info(f'Model `{self.model_name_or_path}` has been downloaded to {self.cache_dir}.')
90+
91+
def prepare(self):
92+
"""
93+
Prepares the model for classification:
94+
- Checks if the model file exists locally.
95+
- Downloads the model if only the name is given and it's known.
96+
- Raises ValueError if the path or model name is invalid.
97+
"""
98+
import fasttext
99+
100+
if not os.path.exists(self.model_name_or_path):
101+
if self.cache_dir and os.path.exists(os.path.join(self.cache_dir, self.model_name_or_path)):
102+
self.model_name_or_path = os.path.join(self.cache_dir, self.model_name_or_path)
103+
elif self.model_name_or_path in self.SUPPROTED_MODELS_URLS:
104+
self._download_model()
105+
else:
106+
raise ValueError(f'Current model is not supported or filepath is invalid: {self.model_name_or_path}.')
107+
108+
self._model = fasttext.load_model(self.model_name_or_path)
109+
110+
def process_dataset_entry(self, data_entry: dict):
111+
"""Applies the classifier to a single dataset entry."""
112+
text = data_entry[self.text_field].strip().replace("\n", " ")
113+
label, prob = self._model.predict(text)
114+
if self.top_k == 1:
115+
data_entry[self.output_field] = label[0].replace('__label__', '')
116+
data_entry[f"{self.output_field}_prob"] = prob[0]
117+
else:
118+
max_k = len(label) if self.top_k == -1 else self.top_k
119+
120+
for _label, _prob, top_i in zip(label, prob, range(1, max_k + 1)):
121+
data_entry[f"{self.output_field}_{top_i}"] = _label.replace('__label__', '')
122+
data_entry[f"{self.output_field}_prob_{top_i}"] = _prob
123+
124+
return [DataEntry(data=data_entry)]

tests/test_fasttext_inference.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from sdp.processors.inference.nlp.fasttext.fasttext import FastTextLangIdClassifier
18+
19+
20+
@pytest.fixture(scope="module")
21+
def classifier():
22+
processor = FastTextLangIdClassifier(
23+
model_name_or_path="lid.176.ftz",
24+
text_field="text",
25+
output_field="lang",
26+
num_workers=1,
27+
batch_size=1,
28+
)
29+
processor.prepare()
30+
return processor
31+
32+
33+
@pytest.mark.parametrize("text,expected_lang", [
34+
("Hello, how are you?", "en"),
35+
("Bonjour tout le monde", "fr"),
36+
("Привет, как дела?", "ru"),
37+
("Hola, ¿cómo estás?", "es"),
38+
])
39+
def test_language_identification(classifier, text, expected_lang):
40+
input_entry = {"text": text}
41+
result = classifier.process_dataset_entry(input_entry)
42+
43+
assert isinstance(result, list)
44+
assert len(result) == 1
45+
46+
output = result[0].data
47+
assert "lang" in output
48+
assert "lang_prob" in output
49+
50+
predicted_lang = output["lang"]
51+
prob = output["lang_prob"]
52+
53+
assert isinstance(predicted_lang, str)
54+
assert 0 <= prob <= 1.0
55+
56+
#Exact matching may depend on the model, so we compare based on presence in the top predictions.
57+
assert predicted_lang == expected_lang, f"Expected: {expected_lang}, got: {predicted_lang}"

0 commit comments

Comments
 (0)