
Commit 14e0ce3

Added CharacterHistograms processor
Signed-off-by: Sasha Meister <sasha.meister.work@gmail.com>
1 parent fe0927d commit 14e0ce3

File tree

2 files changed: +159 additions, −5 deletions


dataset_configs/multilingual/yodas2/config.yaml

Lines changed: 63 additions & 5 deletions
@@ -13,6 +13,9 @@ filters:
   translation:
     source_lang: English
     target_lang: Italian
+    filters:
+      max_len_diff_ratio: 4
+      max_hist_token_ratio: 0.8
 
 processors:
   - _target_: sdp.processors.datasets.yodas2.ListYodas2Data
@@ -232,19 +235,74 @@ processors:
     max_length: 512
     tokenize: False
     add_generation_prompt: True
+
+  - _target_: sdp.processors.CountNumWords
+    output_manifest_file: ${workspace_dir}/${filters.source_lang}/manifest_28.json
+    text_key: pred_text
+    num_words_key: num_words_src
+
+  - _target_: sdp.processors.PreserveByValue
+    output_manifest_file: ${workspace_dir}/${filters.source_lang}/manifest_29.json
+    input_value_key: num_words_src
+    operator: gt
+    target_value: 1
+
+  - _target_: sdp.processors.CountNumWords
+    output_manifest_file: ${workspace_dir}/${filters.source_lang}/manifest_30.json
+    text_key: generation
+    num_words_key: num_words_tgt
+
+  - _target_: sdp.processors.PreserveByValue
+    output_manifest_file: ${workspace_dir}/${filters.source_lang}/manifest_31.json
+    input_value_key: num_words_tgt
+    operator: gt
+    target_value: 1
+
+  - _target_: sdp.processors.LambdaExpression
+    output_manifest_file: ${workspace_dir}/${filters.source_lang}/manifest_32.json
+    new_field: 'len_diff_ratio'
+    expression: max(entry.num_words_src - entry.num_words_tgt, entry.num_words_tgt - entry.num_words_src)
+
+  - _target_: sdp.processors.PreserveByValue
+    output_manifest_file: ${workspace_dir}/${filters.source_lang}/manifest_33.json
+    input_value_key: len_diff_ratio
+    operator: lt
+    target_value: ${translation.filters.max_len_diff_ratio}
+
+  - _target_: sdp.processors.CharacterHistograms
+    output_manifest_file: ${workspace_dir}/${filters.source_lang}/manifest_34.json
+    text_field: pred_text
+    lang: ${filters.source_lang}
+    output_score_field: hist_token_ratio_pred_text
+    cache_dir: ""
+
+  - _target_: sdp.processors.PreserveByValue
+    output_manifest_file: ${workspace_dir}/${filters.source_lang}/manifest_35.json
+    input_value_key: hist_token_ratio_pred_text
+    operator: lt
+    target_value: ${translation.filters.max_hist_token_ratio}
+
+  - _target_: sdp.processors.CharacterHistograms
+    output_manifest_file: ${workspace_dir}/${filters.source_lang}/manifest_34.json
+    text_field: generation
+    lang: it #${filters.source_lang}
+    output_score_field: hist_token_ratio_generation
+    cache_dir: ""
+
+  - _target_: sdp.processors.PreserveByValue
+    output_manifest_file: ${workspace_dir}/${filters.source_lang}/manifest_35.json
+    input_value_key: hist_token_ratio_generation
+    operator: lt
+    target_value: ${translation.filters.max_hist_token_ratio}
 
   - _target_: sdp.processors.CometoidWMTQualityEstimation
     output_manifest_file: ${workspace_dir}/${filters.source_lang}/manifest_28.json
     source_text_field: pred_text #source
     target_text_field: generation #target
     model_name_or_path: cometoid-wmt23
-    device_type: gou
+    device_type: gpu
     num_devices: 4
     chunksize: 10
-
-
-
-
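For reference, the LambdaExpression / PreserveByValue pair added above acts as a word-count difference filter. The short Python sketch below (not part of the commit) mirrors how the configured expression and the `lt` comparison evaluate on a single manifest entry; the entry values and the inlined threshold of 4 are illustrative stand-ins for what the real processors read from the manifest and from `translation.filters.max_len_diff_ratio`.

# Illustrative sketch of the length-difference filter configured above;
# the manifest entry is made up, field names match the config.
from types import SimpleNamespace

entry = SimpleNamespace(num_words_src=12, num_words_tgt=9)

# LambdaExpression: new_field 'len_diff_ratio'
len_diff_ratio = max(entry.num_words_src - entry.num_words_tgt,
                     entry.num_words_tgt - entry.num_words_src)

# PreserveByValue with operator `lt` keeps the entry only if the value stays
# below target_value (${translation.filters.max_len_diff_ratio} = 4 here).
keep = len_diff_ratio < 4
print(len_diff_ratio, keep)  # 3 True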
sdp/processors/metrics/text.py

Lines changed: 96 additions & 0 deletions
@@ -13,6 +13,13 @@
 # limitations under the License.
 
 import re
+import os
+import tempfile
+import shutil
+import requests
+import wget
+import tarfile
+from tqdm import tqdm
 
 from sdp.logging import logger
 from sdp.processors.base_processor import BaseParallelProcessor, DataEntry
@@ -52,4 +59,93 @@ def process_dataset_entry(self, data_entry):
         words = cleaned_string.split()
         num_words = len(words)
         data_entry[self.num_words_key] = num_words
+        return [DataEntry(data=data_entry)]
+
+
+class CharacterHistograms(BaseParallelProcessor):
+    def __init__(self,
+                 text_field: str,
+                 lang_field: str = None,
+                 lang: str = None,
+                 threshold: float = 0.8,
+                 cache_dir: str = None,
+                 threshold_char: str = "]",
+                 output_score_field: str = "hist_token_ratio",
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.text_field = text_field
+
+        if lang_field is None and lang is None:
+            raise ValueError("One of the arguments `lang` or `lang_field` must be provided.")
+
+        if lang_field is not None and lang is not None:
+            raise ValueError(
+                f"Both `lang` ({lang}) and `lang_field` ({lang_field}) are provided, which makes the source of language ambiguous. Please provide only one of them."
+            )
+
+        self.text_field = text_field
+        self.lang_field = lang_field
+        self.lang = lang
+        self.threshold = threshold
+        self.cache_dir = cache_dir
+        self.threshold_char = threshold_char
+        self.output_score_field = output_score_field
+        self.histograms = dict()
+
+    def _read_hist(self, lang: str):
+        hist_file = os.path.join(self.cache_dir, lang)
+        chars = []
+        with open(hist_file) as hist:
+            for line in hist:
+                char = line[0]
+                chars.append(char)
+                if char == self.threshold_char:
+                    break
+        self.histograms[lang] = set(chars)
+
+    def prepare(self):
+        if self.cache_dir is None:
+            self.cache_dir = tempfile.mkdtemp()
+
+        os.makedirs(self.cache_dir, exist_ok=True)
+
+        if not os.path.exists(self.cache_dir):
+            logger.info(f'Downloading histograms to {self.cache_dir}')
+            histograms_url = 'https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz'
+            response = requests.get(histograms_url)
+
+            if response.status_code != 200:
+                raise requests.exceptions.RequestException(
+                    f"Failed to download histogram file. Status code: {response.status_code}"
+                )
+
+            histograms_tarfile = wget.download(histograms_url, out=self.cache_dir)
+            with tarfile.open(histograms_tarfile, "r:gz") as tar:
+                tar.extractall(path=self.cache_dir)
+
+            self.cache_dir = os.path.join(self.cache_dir, "checkpoint/edunov/cc60_multilingual/clean_hists")
+            logger.info(f'Histograms are downloaded.')
+
+        logger.info(f'Reading histograms')
+        available_langs = os.listdir(self.cache_dir)
+        if self.lang is not None:
+            if self.lang in available_langs:
+                self._read_hist(self.lang)
+            else:
+                raise ValueError(f"Invalid value for `lang`: {self.lang}. Please provide one of the following: {available_langs}")
+            logger.info(f'Histogram for `{self.lang}` has been read.')
+        else:
+            for lang in tqdm(available_langs):
+                self._read_hist(lang)
+            logger.info(f'Histograms have been read.')
+
+    def process_dataset_entry(self, data_entry):
+        lang = self.lang if self.lang is not None else data_entry[self.lang_field]
+        if lang not in self.histograms:
+            raise ValueError(f'lang `{lang}` is not supported.')
+
+        text = data_entry[self.text_field].strip()
+        cnt = len([c for c in text if c in self.histograms[lang]])
+        token_ratio = 1 if cnt / len(text) > self.threshold else 0
+        data_entry[self.output_score_field] = token_ratio
         return [DataEntry(data=data_entry)]
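As a usage sketch (not part of the commit), the new processor can be exercised directly. The keyword arguments mirror the config entry above; the cache directory, manifest paths, language code, and sample entry are placeholders, and the `input_manifest_file`/`output_manifest_file` arguments are assumed to be handled by the SDP base processor. The sketch also assumes `cache_dir` already contains the extracted per-language histogram files.

# Illustrative only: running CharacterHistograms outside a full SDP pipeline.
# All paths and values are placeholders; cache_dir is assumed to already hold
# the extracted per-language histogram files (one plain-text file per code).
from sdp.processors.metrics.text import CharacterHistograms

proc = CharacterHistograms(
    text_field="pred_text",
    lang="en",                       # language code matching a histogram file name
    output_score_field="hist_token_ratio_pred_text",
    cache_dir="/path/to/clean_hists",
    input_manifest_file="/path/to/manifest_in.json",    # assumed SDP base-processor args
    output_manifest_file="/path/to/manifest_out.json",
)
proc.prepare()  # reads the histogram for `lang` from cache_dir

entry = {"pred_text": "hello world"}
out = proc.process_dataset_entry(entry)
# Score is 1 if the share of in-histogram characters exceeds `threshold` (default 0.8), else 0.
print(out[0].data["hist_token_ratio_pred_text"])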
