diff --git a/docs/src/sdp/api.rst b/docs/src/sdp/api.rst
index 9d1cb59c..25ff6a2a 100644
--- a/docs/src/sdp/api.rst
+++ b/docs/src/sdp/api.rst
@@ -261,6 +261,9 @@ Data modifications
 .. autodata:: sdp.processors.EstimateBandwidth
    :annotation:
 
+.. autodata:: sdp.processors.CharacterHistogramLangValidator
+   :annotation:
+
 Data filtering
 ''''''''''''''
 
diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py
index dce8e584..b146cfa7 100644
--- a/sdp/processors/__init__.py
+++ b/sdp/processors/__init__.py
@@ -119,6 +119,7 @@
     ListToEntries,
     LambdaExpression,
     EstimateBandwidth,
+    CharacterHistogramLangValidator,
 )
 from sdp.processors.modify_manifest.data_to_dropbool import (
     DropASRError,
diff --git a/sdp/processors/modify_manifest/data_to_data.py b/sdp/processors/modify_manifest/data_to_data.py
index 09a55011..88405d2c 100644
--- a/sdp/processors/modify_manifest/data_to_data.py
+++ b/sdp/processors/modify_manifest/data_to_data.py
@@ -16,6 +16,12 @@
 import os
 import re
 from typing import Dict, List, Optional
+import tempfile
+import shutil
+import requests
+import wget
+import tarfile
+from glob import glob
+from tqdm import tqdm  # used by `prepare` in CharacterHistogramLangValidator below
 
 import soundfile
 import torchaudio
@@ -1316,4 +1322,163 @@ def process_dataset_entry(self, data_entry):
         audio, sr = librosa.load(path=audio_file, sr=self.sample_rate, duration=self.max_seconds)
         bandwidth = self._estimate_bandwidth(audio=audio, sample_rate=sr)
         data_entry[self.output_bandwidth_key] = int(bandwidth)
-        return [DataEntry(data=data_entry)]
\ No newline at end of file
+        return [DataEntry(data=data_entry)]
+
+
+class CharacterHistogramLangValidator(BaseParallelProcessor):
+    """
+    A processor that scores text by its character-histogram similarity to trusted data in the target language.
+
+    It computes the ratio of characters in a given text that are found in a reference character histogram
+    for a specific language. If this ratio falls below a chosen threshold, the text is likely mislabeled
+    or noisy.
+
+    Histograms are sourced from the NLLB paper (https://arxiv.org/pdf/2207.04672; see page 30 for the
+    methodology). The technique is a lightweight language-ID filter designed to catch mismatches between
+    text content and the claimed language.
+
+    Reference implementation:
+    https://github.com/facebookresearch/fairseq/blob/main/examples/m2m_100/process_data/clean_histogram.py
+
+    Args:
+        text_field (str): Key in the data entry containing the text to evaluate.
+        lang_field (str, optional): Key in the data entry that identifies the language.
+            Required if ``lang`` is not provided.
+        lang (str, optional): Language code to use for all entries (overrides ``lang_field``).
+            Required if ``lang_field`` is not provided.
+        threshold (float): Threshold ratio for deciding whether a text matches the histogram.
+            Stored for downstream filtering; this processor only computes the ratio and never drops entries.
+        cache_dir (str, optional): Directory where histograms are downloaded and cached.
+        threshold_char (str): Character at which each histogram file is truncated (default is ``]``).
+        output_score_field (str): Key under which the computed character match ratio is stored.
+        **kwargs: Additional keyword arguments passed to ``BaseParallelProcessor``.
+
+    Raises:
+        ValueError: If both ``lang`` and ``lang_field`` are provided, or if neither is provided.
+            Also raised if the histogram for the specified language is missing.
+
+    Returns:
+        A manifest where each entry includes the additional field ``output_score_field`` with the
+        character match ratio.
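+
+    Example usage (an illustrative sketch; the manifest path and cache directory below
+    are placeholders)::
+
+        processor = CharacterHistogramLangValidator(
+            text_field="text",
+            lang="en",
+            cache_dir="/path/to/histograms",
+            output_manifest_file="manifest_scored.json",
+        )
+        processor.prepare()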
+
+    Example of an output manifest entry::
+
+        {
+            "text": "hello world",
+            "lang": "en",
+            "hist_token_ratio": 0.95
+        }
+    """
+
+    HISTOGRAMS_URL = 'https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz'
+
+    def __init__(self,
+                 text_field: str,
+                 lang_field: Optional[str] = None,
+                 lang: Optional[str] = None,
+                 threshold: float = 0.8,
+                 cache_dir: Optional[str] = None,
+                 threshold_char: str = "]",
+                 output_score_field: str = "hist_token_ratio",
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.text_field = text_field
+
+        # Ensure exactly one of `lang` or `lang_field` is provided
+        if lang_field is None and lang is None:
+            raise ValueError("One of the arguments `lang` or `lang_field` must be provided.")
+        if lang_field is not None and lang is not None:
+            raise ValueError(
+                f"Both `lang` ({lang}) and `lang_field` ({lang_field}) are provided, which makes the "
+                "source of the language ambiguous. Please provide only one of them."
+            )
+
+        self.lang_field = lang_field
+        self.lang = lang
+        self.threshold = threshold
+        self.cache_dir = cache_dir
+        self.threshold_char = threshold_char
+        self.output_score_field = output_score_field
+        self.histograms = dict()
+
+    def _read_hist(self, lang: str):
+        """
+        Read the histogram file for a given language, keeping characters up to and including
+        the first occurrence of ``threshold_char``.
+        """
+        hist_file = os.path.join(self.cache_dir, lang)
+        chars = []
+        with open(hist_file, encoding='utf-8') as hist:
+            for line in hist:
+                char = line[0]
+                chars.append(char)
+                if char == self.threshold_char:
+                    break
+        self.histograms[lang] = set(chars)
+
+    def _download_histograms(self):
+        """
+        Download and extract histogram files into the cache directory.
+        """
+        logger.info('Downloading histograms collection...')
+        # Check availability with a streaming request so the archive body is not
+        # downloaded twice (wget performs the actual download below).
+        with requests.get(self.HISTOGRAMS_URL, stream=True) as response:
+            if response.status_code != 200:
+                raise requests.exceptions.RequestException(
+                    f"Failed to download histograms archive. Status code: {response.status_code}"
+                )
+
+        if self.cache_dir is None:
+            self.cache_dir = tempfile.mkdtemp()
+
+        os.makedirs(self.cache_dir, exist_ok=True)
+
+        histograms_tarfile = wget.download(self.HISTOGRAMS_URL, out=self.cache_dir)
+        with tarfile.open(histograms_tarfile, "r:gz") as tar:
+            tar.extractall(path=self.cache_dir)
+
+        # Flatten the archive's nested directories into cache_dir so that
+        # os.listdir(cache_dir) yields exactly one histogram file per language.
+        histograms_filepaths = glob(f'{self.cache_dir}/checkpoint/edunov/cc60_multilingual/clean_hists/*')
+        for histogram_filepath in histograms_filepaths:
+            shutil.move(histogram_filepath, os.path.join(self.cache_dir, os.path.basename(histogram_filepath)))
+
+        os.remove(histograms_tarfile)
+        # Remove the whole extracted `checkpoint` tree, not just its deepest directory,
+        # so the leftover directory is not later mistaken for a language file.
+        shutil.rmtree(os.path.join(self.cache_dir, 'checkpoint'))
+        logger.info(f'Histograms have been downloaded to {self.cache_dir}.')
+
+    def prepare(self):
+        """
+        Ensure histograms are available and read them into memory.
+        """
+        if (self.cache_dir is None
+                or not os.path.isdir(self.cache_dir)
+                or len(os.listdir(self.cache_dir)) == 0):
+            self._download_histograms()
+
+        logger.info('Reading histograms...')
+        available_langs = os.listdir(self.cache_dir)
+        if self.lang is not None:
+            if self.lang not in available_langs:
+                raise ValueError(
+                    f"Invalid value for `lang`: {self.lang}. Please provide one of the following: {available_langs}"
+                )
+            self._read_hist(self.lang)
+            logger.info(f'Histogram for `{self.lang}` has been read.')
+        else:
+            for lang in tqdm(available_langs):
+                self._read_hist(lang)
+            logger.info('Histograms have been read.')
+
+    def process_dataset_entry(self, data_entry):
+        """
+        Compute and attach the character histogram match ratio for a given text entry.
+
+        Args:
+            data_entry (dict): A dictionary containing at least ``text_field`` and either
+                ``lang_field`` or a preset ``lang``.
+
+        Returns:
+            List[DataEntry]: A list with one updated ``DataEntry`` including the character
+            match ratio field.
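+
+        Example (an illustrative sketch; assumes a ``processor`` configured for ``en`` as in
+        the class docstring, with ``prepare`` already called)::
+
+            entry = {"text": "hello", "lang": "en"}
+            entry = processor.process_dataset_entry(entry)[0].data
+            entry["hist_token_ratio"]  # -> 1.0 if every character is in the `en` histogram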
+        """
+        # Determine the language for this entry
+        lang = self.lang if self.lang is not None else data_entry[self.lang_field]
+        if lang not in self.histograms:
+            raise ValueError(f'lang `{lang}` is not supported.')
+
+        # Count the characters that appear in the language's histogram
+        text = data_entry[self.text_field].strip()
+        cnt = sum(1 for c in text if c in self.histograms[lang])
+        token_ratio = cnt / len(text) if len(text) > 0 else 0.0
+
+        # Store the ratio in the data entry
+        data_entry[self.output_score_field] = token_ratio
+        return [DataEntry(data=data_entry)]
diff --git a/tests/test_data_to_data.py b/tests/test_data_to_data.py
index 9dd3278a..b9f2007f 100644
--- a/tests/test_data_to_data.py
+++ b/tests/test_data_to_data.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 
 import pytest
+import os
+import boto3
+from botocore.exceptions import ClientError, NoCredentialsError
 
 from sdp.processors.modify_manifest.data_to_data import (
     InsIfASRInsertion,
@@ -21,6 +24,7 @@
     SubRegex,
     ListToEntries,
     LambdaExpression,
+    CharacterHistogramLangValidator,
 )
 from sdp.processors.inference.llm.utils.qwen_cleaning import CleanQwenGeneration
 
@@ -282,9 +286,65 @@ def test_detect_whisper_hallucinations(tmp_path, text, expected_flags):
     for key, value in expected_flags.items():
         assert result_entry[key] == value, f"Failed for text='{text}' on key='{key}'"
 
+@pytest.fixture(scope="session")
+def en_hist_dir(tmp_path_factory):
+    """
+    Download the English histogram from S3 once per test session and return the
+    directory path that contains it.
+
+    Uses ``tmp_path_factory`` so a single temp dir persists for the whole session.
+    """
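+    # Credentials are read from the environment (placeholder values):
+    #   export AWS_ACCESS_KEY=...
+    #   export AWS_SECRET_KEY=...
+    # If unset, boto3 falls back to its default credential chain, and the test is
+    # skipped below when no usable credentials are found.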
+ """ + s3 = boto3.client('s3', + aws_access_key_id=os.getenv("AWS_ACCESS_KEY"), + aws_secret_access_key=os.getenv("AWS_SECRET_KEY") + ) + + bucket = "sdp-test-data" + key = "test_data/test_processors/CharacterHistogramLangValidator/histograms/en" + + tmp_dir = tmp_path_factory.mktemp("char_hists") + local_path = tmp_dir / "en" + + if not local_path.exists(): + try: + s3.download_file(bucket, key, str(local_path)) + except ClientError as e: + code = e.response.get("Error", {}).get("Code", "") + pytest.skip(f"Cannot download s3://{bucket}/{key} ({code}).") + + assert local_path.exists(), "Histogram file was not downloaded" + return str(tmp_dir) + +@pytest.mark.parametrize( + "text,expected", + [ + # Plain English sentence; all characters expected in 'en' histogram -> ratio 1.0 + ("Hello, how are you today?", 1.0), + # # Chinese characters; none expected in 'en' histogram -> ratio 0.0 + ("今天天气很好,我们去公园吧。", 0.0), + # Symbols + digits; only digits 1..5 expected in 'en' histogram -> 5 matches out of 17 chars + ("@#$%^&*()_+=12345", 5 / 17), # 0.29411764705882354 + # French sentence with one accented char 'é' not in 'en' histogram -> 23 matches out of 24 chars + ("C'est une belle journée.", 23 / 24), # 0.9583333333333334 + ], +) +def test_character_hist_validator(text, expected, en_hist_dir): + processor = CharacterHistogramLangValidator( + text_field="text", + lang="en", + cache_dir=en_hist_dir, + output_manifest_file=None, + ) + processor.prepare() + + entry = {"text": text} + result_entry = processor.process_dataset_entry(entry)[0].data + + assert result_entry[processor.output_score_field] == pytest.approx(expected, rel=1e-12) + @pytest.mark.parametrize("test_class,class_kwargs,test_input,expected_output", test_params_list, ids=str) def test_data_to_data(test_class, class_kwargs, test_input, expected_output): processor = test_class(**class_kwargs, output_manifest_file=None) result = [entry.data for entry in processor.process_dataset_entry(test_input)] - assert result == expected_output \ No newline at end of file + assert result == expected_output