1616import os
1717import re
1818from typing import Dict , List , Optional
19+ import tempfile
20+ import shutil
21+ import requests
22+ import wget
23+ import tarfile
24+ from glob import glob
1925
2026import soundfile
2127import torchaudio
@@ -1316,4 +1322,163 @@ def process_dataset_entry(self, data_entry):
13161322 audio , sr = librosa .load (path = audio_file , sr = self .sample_rate , duration = self .max_seconds )
13171323 bandwidth = self ._estimate_bandwidth (audio = audio , sample_rate = sr )
13181324 data_entry [self .output_bandwidth_key ] = int (bandwidth )
1325+ return [DataEntry (data = data_entry )]
1326+
1327+
1328+ class CharacterHistogramLangValidator (BaseParallelProcessor ):
1329+ """
1330+ A processor that filters text based on character histogram similarity to trusted data in the target language.
1331+
1332+ This processor computes the ratio of characters in a given text that are found in a reference character histogram
1333+ for a specific language. If this ratio is below a certain threshold, the text is likely mislabeled or noisy.
1334+
1335+ Histograms are sourced from the NLLB paper (https://arxiv.org/pdf/2207.04672), see page 30 for methodology. This
1336+ technique is a lightweight language ID filter, designed to catch mismatches between text content and claimed language.
1337+
1338+ Reference implementation: https://github.com/facebookresearch/fairseq/blob/main/examples/m2m_100/process_data/clean_histogram.py
1339+
1340+ Args:
1341+ text_field (str): Key in the data entry containing the text to evaluate.
1342+ lang_field (str, optional): Key in the data entry that identifies the language. Required if `lang` is not provided.
1343+ lang (str, optional): Language code to use for all entries (overrides `lang_field`). Required if `lang_field` is not provided.
1344+ threshold (float): Threshold ratio to determine if text matches the histogram. Used only externally (not enforced in this processor).
1345+ cache_dir (str, optional): Directory where histograms are downloaded and cached.
1346+ threshold_char (str): Character used to truncate the histogram file (default is ']').
1347+ output_score_field (str): Key name under which the computed character match ratio will be stored.
1348+ **kwargs: Additional keyword arguments passed to `BaseParallelProcessor`.
1349+
1350+ Raises:
1351+ ValueError: If both `lang` and `lang_field` are provided, or if neither is provided.
1352+ Also raised if histogram for specified language is missing.
1353+
1354+ Returns:
1355+ A manifest where each entry includes the additional field `output_score_field` with the character match ratio.
1356+ Example::
1357+
1358+ {
1359+ "text": "hello world",
1360+ "lang": "en",
1361+ "hist_token_ratio": 0.95
1362+ }
1363+ """
1364+
1365+ HISTOGRAMS_URL = 'https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz'
1366+
1367+ def __init__ (self ,
1368+ text_field : str ,
1369+ lang_field : str = None ,
1370+ lang : str = None ,
1371+ threshold : float = 0.8 ,
1372+ cache_dir : str = None ,
1373+ threshold_char : str = "]" ,
1374+ output_score_field : str = "hist_token_ratio" ,
1375+ ** kwargs ):
1376+ super ().__init__ (** kwargs )
1377+ self .text_field = text_field
1378+
1379+ # Ensure exactly one of `lang` or `lang_field` is provided
1380+ if lang_field is None and lang is None :
1381+ raise ValueError ("One of the arguments `lang` or `lang_field` must be provided." )
1382+ if lang_field is not None and lang is not None :
1383+ raise ValueError (
1384+ f"Both `lang` ({ lang } ) and `lang_field` ({ lang_field } ) are provided, which makes the source of language ambiguous. Please provide only one of them."
1385+ )
1386+
1387+ self .lang_field = lang_field
1388+ self .lang = lang
1389+ self .threshold = threshold
1390+ self .cache_dir = cache_dir
1391+ self .threshold_char = threshold_char
1392+ self .output_score_field = output_score_field
1393+ self .histograms = dict ()
1394+
1395+ def _read_hist (self , lang : str ):
1396+ """
1397+ Read and parse the histogram file for a given language, stopping at the threshold character.
1398+ """
1399+ hist_file = os .path .join (self .cache_dir , lang )
1400+ chars = []
1401+ with open (hist_file ) as hist :
1402+ for line in hist :
1403+ char = line [0 ]
1404+ chars .append (char )
1405+ if char == self .threshold_char :
1406+ break
1407+ self .histograms [lang ] = set (chars )
1408+
1409+ def _download_histograms (self ):
1410+ """
1411+ Download and extract histogram files into the cache directory.
1412+ """
1413+ logger .info ('Downloading histograms collection..' )
1414+ response = requests .get (self .HISTOGRAMS_URL )
1415+ if response .status_code != 200 :
1416+ raise requests .exceptions .RequestException (
1417+ f"Failed to download model file. Status code: { response .status_code } "
1418+ )
1419+
1420+ if self .cache_dir is None :
1421+ self .cache_dir = tempfile .mkdtemp ()
1422+
1423+ os .makedirs (self .cache_dir , exist_ok = True )
1424+
1425+ histograms_tarfile = wget .download (self .HISTOGRAMS_URL , out = self .cache_dir )
1426+ with tarfile .open (histograms_tarfile , "r:gz" ) as tar :
1427+ tar .extractall (path = self .cache_dir )
1428+
1429+ # Flatten subdirectories into the main cache_dir
1430+ histograms_filepaths = glob (f'{ self .cache_dir } /checkpoint/edunov/cc60_multilingual/clean_hists/*' )
1431+ for histogram_filepath in histograms_filepaths :
1432+ shutil .move (histogram_filepath , os .path .join (self .cache_dir , os .path .basename (histogram_filepath )))
1433+
1434+ os .remove (histograms_tarfile )
1435+ shutil .rmtree (f'{ self .cache_dir } /checkpoint/edunov/cc60_multilingual/clean_hists/' )
1436+ logger .info (f'Histograms have been downloaded to { self .cache_dir } .' )
1437+
1438+ def prepare (self ):
1439+ """
1440+ Ensure histograms are available and read them into memory.
1441+ """
1442+ if (self .cache_dir is None or
1443+ not os .path .exists (self .cache_dir ) or
1444+ not os .path .isdir (self .cache_dir ) or
1445+ len (os .listdir (self .cache_dir )) == 0 ):
1446+
1447+ self ._download_histograms ()
1448+
1449+ logger .info ('Reading histograms...' )
1450+ available_langs = os .listdir (self .cache_dir )
1451+ if self .lang is not None :
1452+ if self .lang in available_langs :
1453+ self ._read_hist (self .lang )
1454+ else :
1455+ raise ValueError (f"Invalid value for `lang`: { self .lang } . Please provide one of the following: { available_langs } " )
1456+ logger .info (f'Histogram for `{ self .lang } ` has been read.' )
1457+ else :
1458+ for lang in tqdm (available_langs ):
1459+ self ._read_hist (lang )
1460+ logger .info (f'Histograms have been read.' )
1461+
1462+ def process_dataset_entry (self , data_entry ):
1463+ """
1464+ Compute and attach the character histogram match ratio for a given text entry.
1465+
1466+ Args:
1467+ data_entry (dict): A dictionary containing at least `text_field` and either `lang_field` or a preset `lang`.
1468+
1469+ Returns:
1470+ List[DataEntry]: A list with one updated `DataEntry` including the character match ratio field.
1471+ """
1472+ # Determine language for this entry
1473+ lang = self .lang if self .lang is not None else data_entry [self .lang_field ]
1474+ if lang not in self .histograms :
1475+ raise ValueError (f'lang `{ lang } ` is not supported.' )
1476+
1477+ # Compute how many characters match the histogram
1478+ text = data_entry [self .text_field ].strip ()
1479+ cnt = len ([c for c in text if c in self .histograms [lang ]])
1480+ token_ratio = cnt / len (text ) if len (text ) > 0 else 0.0
1481+
1482+ # Store the ratio in the data entry
1483+ data_entry [self .output_score_field ] = token_ratio
13191484 return [DataEntry (data = data_entry )]
0 commit comments