1313# limitations under the License.
1414
1515import re
16+ import os
17+ import tempfile
18+ import shutil
19+ import requests
20+ import wget
21+ import tarfile
22+ from tqdm import tqdm
1623
1724from sdp .logging import logger
1825from sdp .processors .base_processor import BaseParallelProcessor , DataEntry
@@ -52,4 +59,93 @@ def process_dataset_entry(self, data_entry):
5259 words = cleaned_string .split ()
5360 num_words = len (words )
5461 data_entry [self .num_words_key ] = num_words
62+ return [DataEntry (data = data_entry )]
63+
64+
class CharacterHistograms(BaseParallelProcessor):
    """Score entries by how many of their characters appear in a language histogram.

    A histogram is a text file named after a language; the first character of
    each of its lines, up to and including the line that starts with
    ``threshold_char``, forms the set of accepted characters for that
    language. For every entry the processor writes ``1`` to
    ``output_score_field`` when the fraction of characters of the stripped
    text found in that set is strictly greater than ``threshold``, and ``0``
    otherwise.

    Args:
        text_field: key of the entry that holds the text to score.
        lang_field: key of the entry that holds its language. Mutually
            exclusive with ``lang``.
        lang: a single language used for all entries. Mutually exclusive
            with ``lang_field``.
        threshold: ratio that must be exceeded for an entry to score ``1``.
        cache_dir: directory with the histogram files. When it is ``None`` a
            temporary directory is created; when it is empty the histograms
            are downloaded into it.
        threshold_char: character that marks the last histogram line to read.
        output_score_field: key under which the resulting ``0``/``1`` is stored.
        **kwargs: forwarded to ``BaseParallelProcessor``.

    Raises:
        ValueError: if neither or both of ``lang`` and ``lang_field`` are given.
    """

    def __init__(
        self,
        text_field: str,
        lang_field: str = None,
        lang: str = None,
        threshold: float = 0.8,
        cache_dir: str = None,
        threshold_char: str = "]",
        output_score_field: str = "hist_token_ratio",
        **kwargs
    ):
        super().__init__(**kwargs)

        if lang_field is None and lang is None:
            raise ValueError("One of the arguments `lang` or `lang_field` must be provided.")

        if lang_field is not None and lang is not None:
            raise ValueError(
                f"Both `lang` ({lang}) and `lang_field` ({lang_field}) are provided, which makes the source of language ambiguous. Please provide only one of them."
            )

        self.text_field = text_field
        self.lang_field = lang_field
        self.lang = lang
        self.threshold = threshold
        self.cache_dir = cache_dir
        self.threshold_char = threshold_char
        self.output_score_field = output_score_field
        # Maps a language name to the set of characters read from its histogram.
        self.histograms = dict()

    def _read_hist(self, lang: str):
        """Load the accepted character set for ``lang`` into ``self.histograms``.

        Reads ``<cache_dir>/<lang>`` line by line, collecting the first
        character of every line and stopping after the line whose first
        character equals ``threshold_char`` (that character is included).
        """
        hist_file = os.path.join(self.cache_dir, lang)
        chars = set()
        # Explicit encoding so decoding does not depend on the platform locale.
        with open(hist_file, encoding="utf-8") as hist:
            for line in hist:
                char = line[0]
                chars.add(char)
                if char == self.threshold_char:
                    break
        self.histograms[lang] = chars

    def prepare(self):
        """Make the histogram files available and load them into memory.

        Resolution order for the histogram directory:

        1. ``<cache_dir>/checkpoint/edunov/cc60_multilingual/clean_hists`` if it
           already exists (left by a previous download);
        2. a freshly downloaded and extracted archive if ``cache_dir`` is empty;
        3. ``cache_dir`` itself otherwise (assumed to contain histogram files).

        After this call ``self.cache_dir`` points at the directory that holds
        the histogram files.

        Raises:
            requests.exceptions.RequestException: if the archive URL does not
                answer with status code 200.
            ValueError: if ``lang`` is set but has no histogram file.
        """
        if self.cache_dir is None:
            self.cache_dir = tempfile.mkdtemp()

        os.makedirs(self.cache_dir, exist_ok=True)

        extracted_dir = os.path.join(self.cache_dir, "checkpoint/edunov/cc60_multilingual/clean_hists")
        if os.path.isdir(extracted_dir):
            self.cache_dir = extracted_dir
        elif not os.listdir(self.cache_dir):
            logger.info(f'Downloading histograms to {self.cache_dir}')
            histograms_url = 'https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz'

            # stream=True fetches only the headers here, so the archive is not
            # downloaded twice (once for this check and once by wget).
            response = requests.get(histograms_url, stream=True)
            try:
                if response.status_code != 200:
                    raise requests.exceptions.RequestException(
                        f"Failed to download histogram file. Status code: {response.status_code}"
                    )
            finally:
                response.close()

            histograms_tarfile = wget.download(histograms_url, out=self.cache_dir)
            # NOTE: the archive comes from the fixed URL above; do not reuse
            # this extraction for archives from untrusted sources.
            with tarfile.open(histograms_tarfile, "r:gz") as tar:
                tar.extractall(path=self.cache_dir)

            self.cache_dir = extracted_dir
            logger.info(f'Histograms are downloaded.')

        logger.info(f'Reading histograms')
        # Only regular files are histograms; skip any sub-directories.
        available_langs = sorted(
            name for name in os.listdir(self.cache_dir) if os.path.isfile(os.path.join(self.cache_dir, name))
        )
        if self.lang is not None:
            if self.lang not in available_langs:
                raise ValueError(
                    f"Invalid value for `lang`: {self.lang}. Please provide one of the following: {available_langs}"
                )
            self._read_hist(self.lang)
            logger.info(f'Histogram for `{self.lang}` has been read.')
        else:
            for lang in tqdm(available_langs):
                self._read_hist(lang)
            logger.info(f'Histograms have been read.')

    def process_dataset_entry(self, data_entry):
        """Store the histogram-based score of one entry in ``output_score_field``.

        The language is ``self.lang`` when set, otherwise the value of the
        entry's ``lang_field``. Text that is empty after stripping scores ``0``
        because its ratio is undefined.

        Raises:
            ValueError: if no histogram has been loaded for the language.
        """
        lang = self.lang if self.lang is not None else data_entry[self.lang_field]
        if lang not in self.histograms:
            raise ValueError(f'lang `{lang}` is not supported.')

        allowed_chars = self.histograms[lang]
        text = data_entry[self.text_field].strip()
        if text:
            cnt = sum(1 for c in text if c in allowed_chars)
            token_ratio = 1 if cnt / len(text) > self.threshold else 0
        else:
            # Avoid division by zero on empty / whitespace-only text.
            token_ratio = 0
        data_entry[self.output_score_field] = token_ratio
        return [DataEntry(data=data_entry)]
0 commit comments