1+ # Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+ import os
16+ import requests
17+ import tempfile
18+ import wget
19+
20+ from sdp .logging import logger
21+ from sdp .processors .base_processor import BaseParallelProcessor , DataEntry
22+
23+
24+ class FastTextLangIdClassifier (BaseParallelProcessor ):
25+ """
26+ This processor supports language identification using pretrained FastText models.
27+ It classifies text and adds the predicted label and probability to the dataset entry.
28+ If needed, it downloads the model, loads it into memory, and performs prediction on the
29+ specified input text field.
30+
31+ Args:
32+ model_name_or_path (str): Path to a FastText model file or the name of a supported remote model
33+ ('lid.176.bin' or 'lid.176.ftz').
34+ text_field (str): The name of the field in the dataset entry that contains the input text for classification.
35+ output_field (str): The name of the field to store the predicted label. Defaults to "label".
36+ top_k (int): The number of top predictions to return. Defaults to 1 (-1 for all).
37+ cache_dir (str, optional): Directory to store the downloaded model file. If not provided, a temporary
38+ directory is used.
39+ **kwargs: Additional keyword arguments passed to `BaseParallelProcessor`.
40+
41+ Returns:
42+ A manifest where each entry contains the original data fields plus
43+ - `<output_field>`: The predicted label (e.g., language code for `lid.176.bin`),
44+ - `<output_field>_prob`: The probability of the prediction.
45+
46+ Note:
47+ Make sure to install `fasttext` before using this processor:
48+ `pip install fasttext`
49+ """
50+
51+ SUPPROTED_MODELS_URLS = {
52+ 'lid.176.bin' : 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin' ,
53+ 'lid.176.ftz' : 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz'
54+ }
55+
56+ def __init__ (
57+ self ,
58+ model_name_or_path : str ,
59+ text_field : str ,
60+ output_field : str = "label" ,
61+ top_k : int = 1 ,
62+ cache_dir : str = None ,
63+ ** kwargs
64+ ):
65+ super ().__init__ (** kwargs )
66+ self .model_name_or_path = model_name_or_path
67+ self .text_field = text_field
68+ self .output_field = output_field
69+ self .cache_dir = cache_dir
70+ self .top_k = top_k
71+ self ._model = None
72+
73+ def _download_model (self ):
74+ """Downloads the FastText model from a predefined URL and stores it in the cache directory."""
75+ model_url = self .SUPPROTED_MODELS_URLS [self .model_name_or_path ]
76+ logger .info (f'Downloading { self .model_name_or_path } ..' )
77+ response = requests .get (model_url )
78+
79+ if response .status_code != 200 :
80+ raise requests .exceptions .RequestException (
81+ f"Failed to download model file. Status code: { response .status_code } "
82+ )
83+
84+ if self .cache_dir is None :
85+ self .cache_dir = tempfile .mkdtemp ()
86+ os .makedirs (self .cache_dir , exist_ok = True )
87+
88+ self .model_name_or_path = wget .download (model_url , out = self .cache_dir )
89+ logger .info (f'Model `{ self .model_name_or_path } ` has been downloaded to { self .cache_dir } .' )
90+
91+ def prepare (self ):
92+ """
93+ Prepares the model for classification:
94+ - Checks if the model file exists locally.
95+ - Downloads the model if only the name is given and it's known.
96+ - Raises ValueError if the path or model name is invalid.
97+ """
98+ import fasttext
99+
100+ if not os .path .exists (self .model_name_or_path ):
101+ if self .cache_dir and os .path .exists (os .path .join (self .cache_dir , self .model_name_or_path )):
102+ self .model_name_or_path = os .path .join (self .cache_dir , self .model_name_or_path )
103+ elif self .model_name_or_path in self .SUPPROTED_MODELS_URLS :
104+ self ._download_model ()
105+ else :
106+ raise ValueError (f'Current model is not supported or filepath is invalid: { self .model_name_or_path } .' )
107+
108+ self ._model = fasttext .load_model (self .model_name_or_path )
109+
110+ def process_dataset_entry (self , data_entry : dict ):
111+ """Applies the classifier to a single dataset entry."""
112+ text = data_entry [self .text_field ].strip ().replace ("\n " , " " )
113+ label , prob = self ._model .predict (text )
114+ if self .top_k == 1 :
115+ data_entry [self .output_field ] = label [0 ].replace ('__label__' , '' )
116+ data_entry [f"{ self .output_field } _prob" ] = prob [0 ]
117+ else :
118+ max_k = len (label ) if self .top_k == - 1 else self .top_k
119+
120+ for _label , _prob , top_i in zip (label , prob , range (1 , max_k + 1 )):
121+ data_entry [f"{ self .output_field } _{ top_i } " ] = _label .replace ('__label__' , '' )
122+ data_entry [f"{ self .output_field } _prob_{ top_i } " ] = _prob
123+
124+ return [DataEntry (data = data_entry )]
0 commit comments