# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

import requests
import wget

from sdp.logging import logger
from sdp.processors.base_processor import BaseParallelProcessor, DataEntry
22+
23+
class FastTextLangIdClassifier(BaseParallelProcessor):
    """
    This processor supports language identification using pretrained FastText models.
    It classifies text and adds the predicted label and probability to the dataset entry.
    If needed, it downloads the model, loads it into memory, and performs prediction on the
    specified input text field.

    Args:
        model_name_or_path (str): Path to a FastText model file or the name of a supported remote model ('lid.176.bin' or 'lid.176.ftz').
        text_field (str): The name of the field in the dataset entry that contains the input text for classification.
        output_field (str): The name of the field to store the predicted label. Defaults to "label".
        cache_dir (str, optional): Directory to store the downloaded model file. If not provided, a temporary directory is used.
        **kwargs: Additional keyword arguments passed to `BaseParallelProcessor`.

    Returns:
        A manifest where each entry contains the original data fields plus:
            - `<output_field>`: The predicted label (e.g., language code for `lid.176.bin`).
            - `<output_field>_prob`: The probability of the prediction.

    .. note::
        Make sure to install `fasttext` before using this processor:
            pip install fasttext

    """

    SUPPORTED_MODELS_URLS = {
        'lid.176.bin': 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin',
        'lid.176.ftz': 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz',
    }
    # Backward-compatible alias: the attribute was originally published under
    # this misspelled name, so external references to it must keep working.
    SUPPROTED_MODELS_URLS = SUPPORTED_MODELS_URLS

    def __init__(
        self,
        model_name_or_path: str,
        text_field: str,
        output_field: str = "label",
        cache_dir: str = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_name_or_path = model_name_or_path
        self.text_field = text_field
        self.output_field = output_field
        self.cache_dir = cache_dir
        # Loaded lazily in `prepare()` so that construction stays cheap and
        # worker processes can each load their own model handle.
        self._model = None

    def _download_model(self):
        """Download the FastText model from its predefined URL into the cache directory.

        Streams the file to disk in a single pass and updates
        ``self.model_name_or_path`` to point at the local copy.

        Raises:
            requests.exceptions.RequestException: If the server does not respond with HTTP 200.
        """
        model_url = self.SUPPORTED_MODELS_URLS[self.model_name_or_path]
        logger.info(f'Downloading {self.model_name_or_path}..')

        if self.cache_dir is None:
            self.cache_dir = tempfile.mkdtemp()
        os.makedirs(self.cache_dir, exist_ok=True)

        destination = os.path.join(self.cache_dir, self.model_name_or_path)

        # Stream the payload once, straight to disk. (The previous implementation
        # fetched the whole file into memory with requests just to check the status
        # code, then downloaded it a second time with wget.)
        response = requests.get(model_url, stream=True)
        if response.status_code != 200:
            raise requests.exceptions.RequestException(
                f"Failed to download model file. Status code: {response.status_code}"
            )
        with open(destination, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1 << 20):  # 1 MiB chunks
                f.write(chunk)

        self.model_name_or_path = destination
        logger.info(f'Model `{self.model_name_or_path}` has been downloaded to {self.cache_dir}.')

    def prepare(self):
        """
        Prepares the model for classification:
            - Checks if the model file exists locally.
            - Reuses a previously cached copy from `cache_dir` if present.
            - Downloads the model if only the name is given and it's known.
            - Raises ValueError if the path or model name is invalid.
        """
        # Imported here so the processor can be constructed without fasttext installed.
        import fasttext

        if not os.path.exists(self.model_name_or_path):
            if self.cache_dir and os.path.exists(os.path.join(self.cache_dir, self.model_name_or_path)):
                self.model_name_or_path = os.path.join(self.cache_dir, self.model_name_or_path)
            elif self.model_name_or_path in self.SUPPORTED_MODELS_URLS:
                self._download_model()
            else:
                raise ValueError(f'Current model is not supported or filepath is invalid: {self.model_name_or_path}.')

        self._model = fasttext.load_model(self.model_name_or_path)

    def process_dataset_entry(self, data_entry: dict):
        """Apply the classifier to a single dataset entry.

        Adds `<output_field>` (predicted label with the `__label__` prefix
        stripped) and `<output_field>_prob` (prediction probability) to the entry.
        """
        # fasttext cannot handle newlines inside the input text.
        text = data_entry[self.text_field].strip().replace("\n", " ")
        label, prob = self._model.predict(text)
        data_entry[self.output_field] = label[0].replace('__label__', '')
        # predict() returns a numpy array; cast to a plain float so the
        # manifest entry stays JSON-serializable.
        data_entry[f"{self.output_field}_prob"] = float(prob[0])

        return [DataEntry(data=data_entry)]