3 changes: 3 additions & 0 deletions docs/src/sdp/api.rst
@@ -261,6 +261,9 @@ Data modifications
.. autodata:: sdp.processors.EstimateBandwidth
:annotation:

.. autodata:: sdp.processors.CharacterHistogramLangValidator
:annotation:

Data filtering
''''''''''''''

1 change: 1 addition & 0 deletions sdp/processors/__init__.py
@@ -119,6 +119,7 @@
ListToEntries,
LambdaExpression,
EstimateBandwidth,
CharacterHistogramLangValidator,
)
from sdp.processors.modify_manifest.data_to_dropbool import (
DropASRError,
165 changes: 165 additions & 0 deletions sdp/processors/modify_manifest/data_to_data.py
@@ -16,6 +16,12 @@
import os
import re
from typing import Dict, List, Optional
import tempfile
import shutil
import requests
import wget
import tarfile
from glob import glob

import soundfile
import torchaudio
@@ -1316,4 +1322,163 @@ def process_dataset_entry(self, data_entry):
audio, sr = librosa.load(path=audio_file, sr=self.sample_rate, duration=self.max_seconds)
bandwidth = self._estimate_bandwidth(audio=audio, sample_rate=sr)
data_entry[self.output_bandwidth_key] = int(bandwidth)
return [DataEntry(data=data_entry)]


class CharacterHistogramLangValidator(BaseParallelProcessor):
"""
A processor that scores text by its character-histogram similarity to trusted data in the target language.

This processor computes the ratio of characters in a given text that are found in a reference character histogram
for a specific language. If this ratio is below a certain threshold, the text is likely mislabeled or noisy.

Histograms are sourced from the NLLB paper (https://arxiv.org/pdf/2207.04672; see page 30 for methodology). This
technique is a lightweight language-ID filter designed to catch mismatches between text content and the claimed language.

Reference implementation: https://github.com/facebookresearch/fairseq/blob/main/examples/m2m_100/process_data/clean_histogram.py

Args:
text_field (str): Key in the data entry containing the text to evaluate.
lang_field (str, optional): Key in the data entry that identifies the language. Required if `lang` is not provided.
lang (str, optional): Language code to use for all entries (overrides `lang_field`). Required if `lang_field` is not provided.
threshold (float): Ratio below which text is considered a language mismatch. Stored for downstream filtering; not enforced by this processor.
cache_dir (str, optional): Directory where histograms are downloaded and cached.
threshold_char (str): Character at which histogram reading stops (default ']').
output_score_field (str): Key name under which the computed character match ratio will be stored.
**kwargs: Additional keyword arguments passed to `BaseParallelProcessor`.

Raises:
ValueError: If both `lang` and `lang_field` are provided, or if neither is provided.
Also raised if the histogram for the specified language is missing.

Returns:
A manifest where each entry includes the additional field `output_score_field` with the character match ratio.
Example::

{
"text": "hello world",
"lang": "en",
"hist_token_ratio": 0.95
}
"""

HISTOGRAMS_URL = 'https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz'

def __init__(self,
text_field: str,
lang_field: Optional[str] = None,
lang: Optional[str] = None,
threshold: float = 0.8,
cache_dir: Optional[str] = None,
threshold_char: str = "]",
output_score_field: str = "hist_token_ratio",
**kwargs):
super().__init__(**kwargs)
self.text_field = text_field

# Ensure exactly one of `lang` or `lang_field` is provided
if lang_field is None and lang is None:
raise ValueError("One of the arguments `lang` or `lang_field` must be provided.")
if lang_field is not None and lang is not None:
raise ValueError(
f"Both `lang` ({lang}) and `lang_field` ({lang_field}) are provided, which makes the source of language ambiguous. Please provide only one of them."
)

self.lang_field = lang_field
self.lang = lang
self.threshold = threshold
self.cache_dir = cache_dir
self.threshold_char = threshold_char
self.output_score_field = output_score_field
self.histograms = dict()

def _read_hist(self, lang: str):
"""
Read and parse the histogram file for a given language, stopping at the threshold character.
"""
hist_file = os.path.join(self.cache_dir, lang)
chars = []
with open(hist_file) as hist:
for line in hist:
char = line[0]
chars.append(char)
if char == self.threshold_char:
break
self.histograms[lang] = set(chars)

def _download_histograms(self):
"""
Download and extract histogram files into the cache directory.
"""
logger.info('Downloading histogram collection...')
# Verify availability with a HEAD request so the tarball is not downloaded twice.
response = requests.head(self.HISTOGRAMS_URL, allow_redirects=True)
if response.status_code != 200:
raise requests.exceptions.RequestException(
f"Failed to download histograms. Status code: {response.status_code}"
)

if self.cache_dir is None:
self.cache_dir = tempfile.mkdtemp()

os.makedirs(self.cache_dir, exist_ok=True)

histograms_tarfile = wget.download(self.HISTOGRAMS_URL, out=self.cache_dir)
with tarfile.open(histograms_tarfile, "r:gz") as tar:
tar.extractall(path=self.cache_dir)

# Flatten subdirectories into the main cache_dir
histograms_filepaths = glob(f'{self.cache_dir}/checkpoint/edunov/cc60_multilingual/clean_hists/*')
for histogram_filepath in histograms_filepaths:
shutil.move(histogram_filepath, os.path.join(self.cache_dir, os.path.basename(histogram_filepath)))

os.remove(histograms_tarfile)
# Remove the whole extracted directory tree, not just the leaf folder, so that
# os.listdir(cache_dir) later sees only histogram files.
shutil.rmtree(os.path.join(self.cache_dir, 'checkpoint'))
logger.info(f'Histograms have been downloaded to {self.cache_dir}.')

def prepare(self):
"""
Ensure histograms are available and read them into memory.
"""
if (self.cache_dir is None or
not os.path.exists(self.cache_dir) or
not os.path.isdir(self.cache_dir) or
len(os.listdir(self.cache_dir)) == 0):

self._download_histograms()

logger.info('Reading histograms...')
available_langs = os.listdir(self.cache_dir)
if self.lang is not None:
if self.lang in available_langs:
self._read_hist(self.lang)
else:
raise ValueError(f"Invalid value for `lang`: {self.lang}. Please provide one of the following: {available_langs}")
logger.info(f'Histogram for `{self.lang}` has been read.')
else:
for lang in tqdm(available_langs):
self._read_hist(lang)
logger.info('Histograms have been read.')

def process_dataset_entry(self, data_entry):
"""
Compute and attach the character histogram match ratio for a given text entry.

Args:
data_entry (dict): A dictionary containing at least `text_field` and either `lang_field` or a preset `lang`.

Returns:
List[DataEntry]: A list with one updated `DataEntry` including the character match ratio field.
"""
# Determine language for this entry
lang = self.lang if self.lang is not None else data_entry[self.lang_field]
if lang not in self.histograms:
raise ValueError(f'lang `{lang}` is not supported.')

# Compute how many characters match the histogram
text = data_entry[self.text_field].strip()
cnt = len([c for c in text if c in self.histograms[lang]])
token_ratio = cnt / len(text) if len(text) > 0 else 0.0

# Store the ratio in the data entry
data_entry[self.output_score_field] = token_ratio
return [DataEntry(data=data_entry)]
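A minimal usage sketch of the new processor, mirroring the test flow below; the cache directory and the sample entry are illustrative values, not part of the PR:

from sdp.processors import CharacterHistogramLangValidator

processor = CharacterHistogramLangValidator(
    text_field="text",
    lang="en",                    # score every entry against the English histogram
    cache_dir="/tmp/char_hists",  # histograms are downloaded here on first use
    output_manifest_file=None,
)
processor.prepare()  # downloads histograms if the cache is empty, then loads them

entry = {"text": "Hello, how are you today?"}
scored = processor.process_dataset_entry(entry)[0].data
print(scored["hist_token_ratio"])  # 1.0 when every character is in the histogram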
62 changes: 61 additions & 1 deletion tests/test_data_to_data.py
@@ -13,6 +13,9 @@
# limitations under the License.

import pytest
import os
import boto3
from botocore.exceptions import ClientError

from sdp.processors.modify_manifest.data_to_data import (
InsIfASRInsertion,
@@ -21,6 +24,7 @@
SubRegex,
ListToEntries,
LambdaExpression,
CharacterHistogramLangValidator,
)

from sdp.processors.inference.llm.utils.qwen_cleaning import CleanQwenGeneration
@@ -282,9 +286,65 @@ def test_detect_whisper_hallucinations(tmp_path, text, expected_flags):
for key, value in expected_flags.items():
assert result_entry[key] == value, f"Failed for text='{text}' on key='{key}'"

@pytest.fixture(scope="session")
def en_hist_dir(tmp_path_factory):
"""
Download the English histogram from S3 just once
and return the directory path that contains it.

Uses tmp_path_factory so a single temp dir persists for the whole session.
"""
s3 = boto3.client('s3',
aws_access_key_id=os.getenv("AWS_ACCESS_KEY"),
aws_secret_access_key=os.getenv("AWS_SECRET_KEY")
)

bucket = "sdp-test-data"
key = "test_data/test_processors/CharacterHistogramLangValidator/histograms/en"

tmp_dir = tmp_path_factory.mktemp("char_hists")
local_path = tmp_dir / "en"

if not local_path.exists():
try:
s3.download_file(bucket, key, str(local_path))
except ClientError as e:
code = e.response.get("Error", {}).get("Code", "")
pytest.skip(f"Cannot download s3://{bucket}/{key} ({code}).")

assert local_path.exists(), "Histogram file was not downloaded"
return str(tmp_dir)

@pytest.mark.parametrize(
"text,expected",
[
# Plain English sentence; all characters expected in 'en' histogram -> ratio 1.0
("Hello, how are you today?", 1.0),
# Chinese characters; none expected in 'en' histogram -> ratio 0.0
("今天天气很好,我们去公园吧。", 0.0),
# Symbols + digits; only digits 1..5 expected in 'en' histogram -> 5 matches out of 17 chars
("@#$%^&*()_+=12345", 5 / 17), # 0.29411764705882354
# French sentence with one accented char 'é' not in 'en' histogram -> 23 matches out of 24 chars
("C'est une belle journée.", 23 / 24), # 0.9583333333333334
],
)
def test_character_hist_validator(text, expected, en_hist_dir):
processor = CharacterHistogramLangValidator(
text_field="text",
lang="en",
cache_dir=en_hist_dir,
output_manifest_file=None,
)
processor.prepare()

entry = {"text": text}
result_entry = processor.process_dataset_entry(entry)[0].data

assert result_entry[processor.output_score_field] == pytest.approx(expected, rel=1e-12)

@pytest.mark.parametrize("test_class,class_kwargs,test_input,expected_output", test_params_list, ids=str)
def test_data_to_data(test_class, class_kwargs, test_input, expected_output):
processor = test_class(**class_kwargs, output_manifest_file=None)
result = [entry.data for entry in processor.process_dataset_entry(test_input)]

assert result == expected_output
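The expected ratios in the parametrized cases above are plain character counting. A standalone sketch of the same arithmetic on a toy histogram (`hist_ratio` and `toy_en_hist` are illustrative names, not part of the PR):

# Same ratio the processor computes, shown on a toy histogram set.
def hist_ratio(text: str, hist: set) -> float:
    matches = sum(1 for c in text if c in hist)
    return matches / len(text) if text else 0.0

toy_en_hist = set("abcdefghijklmnopqrstuvwxyz'., 12345")

# "@#$%^&*()_+=12345" has 17 characters and only the five digits are in the
# histogram, so the ratio is 5/17, the value asserted in the test above.
assert hist_ratio("@#$%^&*()_+=12345", toy_en_hist) == 5 / 17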