Skip to content

Commit d7d4927

Browse files
ssh-meister and root authored
CharacterHistogramLangValidator processor implementation (#154)
* CharacterHistogramLangValidator processor implementation Signed-off-by: Sasha Meister <[email protected]> * Tests added Signed-off-by: root <[email protected]> * tmp_change Signed-off-by: root <[email protected]> * Tmp check Signed-off-by: Sasha Meister <[email protected]> * new test s3 key Signed-off-by: Sasha Meister <[email protected]> * Turn on all tests Signed-off-by: Sasha Meister <[email protected]> * From s3_client to s3_resource * Fix test Signed-off-by: Sasha Meister <[email protected]> * Added try/except to s3 file download Signed-off-by: Sasha Meister <[email protected]> * Removed duplicated row Signed-off-by: Sasha Meister <[email protected]> --------- Signed-off-by: Sasha Meister <[email protected]> Signed-off-by: root <[email protected]> Co-authored-by: root <[email protected]>
1 parent 3fce946 commit d7d4927

File tree

4 files changed

+230
-1
lines changed

4 files changed

+230
-1
lines changed

docs/src/sdp/api.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,9 @@ Data modifications
261261
.. autodata:: sdp.processors.EstimateBandwidth
262262
:annotation:
263263

264+
.. autodata:: sdp.processors.CharacterHistogramLangValidator
265+
:annotation:
266+
264267
Data filtering
265268
''''''''''''''
266269

sdp/processors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@
119119
ListToEntries,
120120
LambdaExpression,
121121
EstimateBandwidth,
122+
CharacterHistogramLangValidator,
122123
)
123124
from sdp.processors.modify_manifest.data_to_dropbool import (
124125
DropASRError,

sdp/processors/modify_manifest/data_to_data.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@
1616
import os
1717
import re
1818
from typing import Dict, List, Optional
19+
import tempfile
20+
import shutil
21+
import requests
22+
import wget
23+
import tarfile
24+
from glob import glob
1925

2026
import soundfile
2127
import torchaudio
@@ -1316,4 +1322,163 @@ def process_dataset_entry(self, data_entry):
13161322
audio, sr = librosa.load(path=audio_file, sr=self.sample_rate, duration=self.max_seconds)
13171323
bandwidth = self._estimate_bandwidth(audio=audio, sample_rate=sr)
13181324
data_entry[self.output_bandwidth_key] = int(bandwidth)
1325+
return [DataEntry(data=data_entry)]
1326+
1327+
1328+
class CharacterHistogramLangValidator(BaseParallelProcessor):
    """
    A processor that filters text based on character histogram similarity to trusted data in the target language.

    This processor computes the ratio of characters in a given text that are found in a reference character
    histogram for a specific language. If this ratio is below a certain threshold, the text is likely
    mislabeled or noisy.

    Histograms are sourced from the NLLB paper (https://arxiv.org/pdf/2207.04672), see page 30 for methodology.
    This technique is a lightweight language ID filter, designed to catch mismatches between text content and
    claimed language.

    Reference implementation:
    https://github.com/facebookresearch/fairseq/blob/main/examples/m2m_100/process_data/clean_histogram.py

    Args:
        text_field (str): Key in the data entry containing the text to evaluate.
        lang_field (str, optional): Key in the data entry that identifies the language. Required if `lang` is not provided.
        lang (str, optional): Language code to use for all entries (overrides `lang_field`). Required if `lang_field` is not provided.
        threshold (float): Threshold ratio to determine if text matches the histogram. Used only externally (not enforced in this processor).
        cache_dir (str, optional): Directory where histograms are downloaded and cached.
        threshold_char (str): Character used to truncate the histogram file (default is ']').
        output_score_field (str): Key name under which the computed character match ratio will be stored.
        **kwargs: Additional keyword arguments passed to `BaseParallelProcessor`.

    Raises:
        ValueError: If both `lang` and `lang_field` are provided, or if neither is provided.
            Also raised if the histogram for the specified language is missing.

    Returns:
        A manifest where each entry includes the additional field `output_score_field` with the character match ratio.
        Example::

            {
                "text": "hello world",
                "lang": "en",
                "hist_token_ratio": 0.95
            }
    """

    HISTOGRAMS_URL = 'https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz'
    # Directory inside the tarball where the per-language histogram files live.
    _ARCHIVE_HISTS_SUBDIR = 'checkpoint/edunov/cc60_multilingual/clean_hists'

    def __init__(self,
                 text_field: str,
                 lang_field: str = None,
                 lang: str = None,
                 threshold: float = 0.8,
                 cache_dir: str = None,
                 threshold_char: str = "]",
                 output_score_field: str = "hist_token_ratio",
                 **kwargs):
        super().__init__(**kwargs)
        self.text_field = text_field

        # Ensure exactly one of `lang` or `lang_field` is provided.
        if lang_field is None and lang is None:
            raise ValueError("One of the arguments `lang` or `lang_field` must be provided.")
        if lang_field is not None and lang is not None:
            raise ValueError(
                f"Both `lang` ({lang}) and `lang_field` ({lang_field}) are provided, which makes the source of language ambiguous. Please provide only one of them."
            )

        self.lang_field = lang_field
        self.lang = lang
        self.threshold = threshold
        self.cache_dir = cache_dir
        self.threshold_char = threshold_char
        self.output_score_field = output_score_field
        self.histograms = dict()

    def _read_hist(self, lang: str):
        """
        Read and parse the histogram file for a given language, stopping at the threshold character.
        """
        hist_file = os.path.join(self.cache_dir, lang)
        chars = []
        with open(hist_file) as hist:
            for line in hist:
                # The first character of each line is the histogram entry itself.
                char = line[0]
                chars.append(char)
                if char == self.threshold_char:
                    break
        self.histograms[lang] = set(chars)

    def _download_histograms(self):
        """
        Download and extract histogram files into the cache directory.
        """
        logger.info('Downloading histograms collection..')

        if self.cache_dir is None:
            self.cache_dir = tempfile.mkdtemp()
        os.makedirs(self.cache_dir, exist_ok=True)

        # Stream the archive straight to disk. The previous implementation fetched
        # the whole tarball into memory with `requests.get` just to inspect the
        # status code, and then downloaded it a second time with `wget`.
        response = requests.get(self.HISTOGRAMS_URL, stream=True)
        if response.status_code != 200:
            raise requests.exceptions.RequestException(
                f"Failed to download histograms file. Status code: {response.status_code}"
            )

        histograms_tarfile = os.path.join(self.cache_dir, os.path.basename(self.HISTOGRAMS_URL))
        with open(histograms_tarfile, 'wb') as out_file:
            shutil.copyfileobj(response.raw, out_file)

        with tarfile.open(histograms_tarfile, "r:gz") as tar:
            tar.extractall(path=self.cache_dir)

        # Flatten the nested archive layout into the main cache_dir.
        histograms_filepaths = glob(f'{self.cache_dir}/{self._ARCHIVE_HISTS_SUBDIR}/*')
        for histogram_filepath in histograms_filepaths:
            shutil.move(histogram_filepath, os.path.join(self.cache_dir, os.path.basename(histogram_filepath)))

        os.remove(histograms_tarfile)
        # Remove the whole extracted `checkpoint/` tree. Removing only the leaf
        # `clean_hists/` directory (as before) left empty `checkpoint/...` folders
        # inside cache_dir, which `prepare()` then mistook for language histograms
        # and crashed trying to open a directory as a file.
        shutil.rmtree(os.path.join(self.cache_dir, 'checkpoint'))
        logger.info(f'Histograms have been downloaded to {self.cache_dir}.')

    def prepare(self):
        """
        Ensure histograms are available and read them into memory.
        """
        # `os.path.isdir` already implies existence, so no separate exists() check is needed.
        if (self.cache_dir is None or
                not os.path.isdir(self.cache_dir) or
                len(os.listdir(self.cache_dir)) == 0):

            self._download_histograms()

        logger.info('Reading histograms...')
        available_langs = os.listdir(self.cache_dir)
        if self.lang is not None:
            if self.lang in available_langs:
                self._read_hist(self.lang)
            else:
                raise ValueError(f"Invalid value for `lang`: {self.lang}. Please provide one of the following: {available_langs}")
            logger.info(f'Histogram for `{self.lang}` has been read.')
        else:
            for lang in tqdm(available_langs):
                self._read_hist(lang)
            logger.info(f'Histograms have been read.')

    def process_dataset_entry(self, data_entry):
        """
        Compute and attach the character histogram match ratio for a given text entry.

        Args:
            data_entry (dict): A dictionary containing at least `text_field` and either `lang_field` or a preset `lang`.

        Returns:
            List[DataEntry]: A list with one updated `DataEntry` including the character match ratio field.
        """
        # Determine language for this entry.
        lang = self.lang if self.lang is not None else data_entry[self.lang_field]
        if lang not in self.histograms:
            raise ValueError(f'lang `{lang}` is not supported.')

        # Compute how many characters match the histogram.
        text = data_entry[self.text_field].strip()
        cnt = len([c for c in text if c in self.histograms[lang]])
        token_ratio = cnt / len(text) if len(text) > 0 else 0.0

        # Store the ratio in the data entry.
        data_entry[self.output_score_field] = token_ratio
        return [DataEntry(data=data_entry)]

tests/test_data_to_data.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# limitations under the License.
1414

1515
import pytest
16+
import os
17+
import boto3
18+
from botocore.exceptions import ClientError
1619

1720
from sdp.processors.modify_manifest.data_to_data import (
1821
InsIfASRInsertion,
@@ -21,6 +24,7 @@
2124
SubRegex,
2225
ListToEntries,
2326
LambdaExpression,
27+
CharacterHistogramLangValidator,
2428
)
2529

2630
from sdp.processors.inference.llm.utils.qwen_cleaning import CleanQwenGeneration
@@ -282,9 +286,65 @@ def test_detect_whisper_hallucinations(tmp_path, text, expected_flags):
282286
for key, value in expected_flags.items():
283287
assert result_entry[key] == value, f"Failed for text='{text}' on key='{key}'"
284288

289+
@pytest.fixture(scope="session")
def en_hist_dir(tmp_path_factory):
    """
    Fetch the English character histogram from S3 a single time per test
    session and hand back the directory that contains it.

    `tmp_path_factory` is session-scoped, so every test requesting this
    fixture shares the same downloaded copy.
    """
    s3_client = boto3.client(
        's3',
        aws_access_key_id=os.getenv("AWS_ACCESS_KEY"),
        aws_secret_access_key=os.getenv("AWS_SECRET_KEY"),
    )

    bucket = "sdp-test-data"
    key = "test_data/test_processors/CharacterHistogramLangValidator/histograms/en"

    hist_dir = tmp_path_factory.mktemp("char_hists")
    hist_path = hist_dir / "en"

    if not hist_path.exists():
        try:
            s3_client.download_file(bucket, key, str(hist_path))
        except ClientError as err:
            code = err.response.get("Error", {}).get("Code", "")
            pytest.skip(f"Cannot download s3://{bucket}/{key} ({code}).")

    assert hist_path.exists(), "Histogram file was not downloaded"
    return str(hist_dir)
318+
@pytest.mark.parametrize(
    "text,expected",
    [
        # All-ASCII English sentence: every character is in the 'en' histogram -> ratio 1.0
        ("Hello, how are you today?", 1.0),
        # Chinese characters: none appear in the 'en' histogram -> ratio 0.0
        ("今天天气很好,我们去公园吧。", 0.0),
        # Symbols + digits: only the digits 1..5 match -> 5 of 17 characters
        ("@#$%^&*()_+=12345", 5 / 17),  # 0.29411764705882354
        # French sentence whose lone accented 'é' is absent from 'en' -> 23 of 24 characters
        ("C'est une belle journée.", 23 / 24),  # 0.9583333333333334
    ],
)
def test_character_hist_validator(text, expected, en_hist_dir):
    validator = CharacterHistogramLangValidator(
        text_field="text",
        lang="en",
        cache_dir=en_hist_dir,
        output_manifest_file=None,
    )
    validator.prepare()

    output = validator.process_dataset_entry({"text": text})[0].data

    assert output[validator.output_score_field] == pytest.approx(expected, rel=1e-12)
344+
285345
@pytest.mark.parametrize("test_class,class_kwargs,test_input,expected_output", test_params_list, ids=str)
def test_data_to_data(test_class, class_kwargs, test_input, expected_output):
    # Instantiate each processor under test without writing an output manifest.
    proc = test_class(**class_kwargs, output_manifest_file=None)
    produced = [entry.data for entry in proc.process_dataset_entry(test_input)]

    assert produced == expected_output

0 commit comments

Comments (0)