Skip to content

Commit c9e72fd

Browse files
committed
Added HfHubDownloadYodas2Data, HfHubDownload, GetGranarysYodas2
Signed-off-by: Sasha Meister <[email protected]>
1 parent 54a0af5 commit c9e72fd

File tree

4 files changed

+153
-29
lines changed

4 files changed

+153
-29
lines changed

sdp/processors/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
from sdp.processors.manage_files.convert_to_tarred_audio_dataset import ConvertToTarredAudioDataset
8282

8383
from sdp.processors.huggingface.create_initial_manifest import CreateInitialManifestHuggingFace
84-
from sdp.processors.huggingface.huggingface_hub import ListRepoFiles, SnapshotDownload
84+
from sdp.processors.huggingface.huggingface_hub import ListRepoFiles, SnapshotDownload, HfHubDownload
8585

8686
from sdp.processors.inference.asr.nemo.asr_inference import ASRInference
8787
from sdp.processors.inference.asr.transformers.speech_recognition import ASRTransformers
@@ -159,6 +159,8 @@
159159
)
160160
from sdp.processors.datasets.yodas2.create_initial_manifest import(
161161
ListYodas2Data,
162-
DownloadYodas2Data,
162+
SnapshotDownloadYodas2Data,
163+
HfHubDownloadYodas2Data,
163164
CreateInitialManifestYodas2,
164-
)
165+
)
166+
from sdp.processors.datasets.yodas2.granary import GetGranarysYodas2

sdp/processors/datasets/yodas2/create_initial_manifest.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import importlib.util
2121

2222
from sdp.processors import ListToEntries
23-
from sdp.processors.huggingface.huggingface_hub import ListRepoFiles, SnapshotDownload
23+
from sdp.processors.huggingface.huggingface_hub import ListRepoFiles, SnapshotDownload, HfHubDownload
2424
from sdp.logging import logger
2525

2626

@@ -137,7 +137,7 @@ def process(self):
137137
logger.info("Metadata successfully saved!")
138138

139139

140-
class DownloadYodas2Data(SnapshotDownload):
140+
class SnapshotDownloadYodas2Data(SnapshotDownload):
141141
"""
142142
A specialized processor for downloading the YODAS2 dataset from Hugging Face
143143
and updating the input manifest with local file paths to the downloaded files.
@@ -194,7 +194,12 @@ class DownloadYodas2Data(SnapshotDownload):
194194

195195
def __init__(self, **kwargs):
196196
# Hardcoded to download the espnet/yodas2 dataset from Hugging Face
197-
super().__init__(repo_id="espnet/yodas2", repo_type="dataset", **kwargs)
197+
if not 'snapshot_download_args' in kwargs:
198+
kwargs['snapshot_download_args'] = dict()
199+
kwargs['snapshot_download_args']['repo_id'] = 'espnet/yodas2'
200+
kwargs['snapshot_download_args']['repo_type'] = 'dataset'
201+
202+
super().__init__(**kwargs)
198203

199204
def write_output_manifest_file(self):
200205
"""
@@ -271,6 +276,18 @@ def process(self):
271276
self.write_output_manifest_file()
272277

273278

279+
class HfHubDownloadYodas2Data(HfHubDownload):
    """
    A specialized processor that downloads individual files of the YODAS2 dataset
    from Hugging Face, one file per input manifest entry.

    The repository is hardcoded to ``espnet/yodas2`` (``repo_type="dataset"``); any
    ``repo_id`` / ``repo_type`` supplied through ``hf_hub_download_args`` is overridden.

    Args:
        filename_field (str): Manifest field holding the path of the file inside the
            repository. Defaults to ``'audio_key'``.
        output_filepath_field (str): Manifest field that receives the local path of the
            downloaded file. Defaults to ``'local_audio'``.
        **kwargs: Forwarded to :class:`HfHubDownload`. May contain ``hf_hub_download_args``
            with extra keyword arguments for ``huggingface_hub.hf_hub_download``.
    """

    def __init__(
        self,
        filename_field: str = 'audio_key',
        output_filepath_field: str = 'local_audio',
        **kwargs,
    ):
        # Work on a copy so the caller's dict (e.g. a shared config object) is never
        # modified; a missing or ``None`` value is treated as "no extra arguments".
        download_args = dict(kwargs.pop('hf_hub_download_args', None) or {})

        # Hardcoded to download from the espnet/yodas2 dataset on Hugging Face.
        download_args['repo_id'] = 'espnet/yodas2'
        download_args['repo_type'] = 'dataset'

        super().__init__(
            filename_field=filename_field,
            output_filepath_field=output_filepath_field,
            hf_hub_download_args=download_args,
            **kwargs,
        )

    # NOTE: ``process`` is intentionally not overridden; the inherited
    # ``HfHubDownload.process`` is used as-is.
290+
274291
class CreateInitialManifestYodas2(ListToEntries):
275292
"""
276293
A dataset processor specialized for the YODAS2 dataset.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import os
2+
import json
3+
from glob import glob
4+
from tqdm import tqdm
5+
import tempfile
6+
7+
from sdp.processors.huggingface.huggingface_hub import SnapshotDownload
8+
from sdp.logging import logger
9+
10+
class GetGranarysYodas2(SnapshotDownload):
    """
    Build a manifest from the Granary annotations of YODAS2 for a single language.

    The per-language manifest files are downloaded from the ``YODASEnj/YDS`` dataset
    repository on Hugging Face into a temporary directory, and every sample is
    rewritten as one JSON line of ``output_manifest_file``.

    Args:
        lang (str): Language code; must be one of :attr:`AVAILABLE_LANGS`.
        translation (bool): If ``True``, read the translation manifests and emit
            English (``translation_en``) as the answer. Ignored (with a warning)
            for ``lang="en"``. Defaults to ``False``.
        **kwargs: Forwarded to :class:`SnapshotDownload`. May contain
            ``snapshot_download_args`` with extra keyword arguments for
            ``huggingface_hub.snapshot_download``; ``repo_id`` and ``repo_type``
            are always overridden.

    Raises:
        ValueError: If ``lang`` is not in :attr:`AVAILABLE_LANGS`.
    """

    AVAILABLE_LANGS = ["bg", "cs", "da", "de", "el",
                       "en", "es", "et", "fi", "fr",
                       "hr", "hu", "it", "lt", "lv",
                       "nl", "pl", "pt", "ro", "ru",
                       "sk", "sv", "uk"]

    def __init__(self, lang: str, translation: bool = False, **kwargs):
        if lang not in self.AVAILABLE_LANGS:
            raise ValueError(
                f"Unsupported language `{lang}`. "
                f"Available languages: {', '.join(self.AVAILABLE_LANGS)}."
            )

        # SnapshotDownload takes the download arguments as one dict, not as bare
        # keyword arguments. Copy it so the caller's dict is not modified.
        download_args = dict(kwargs.pop("snapshot_download_args", None) or {})
        download_args["repo_id"] = "YODASEnj/YDS"
        download_args["repo_type"] = "dataset"
        super().__init__(snapshot_download_args=download_args, **kwargs)

        self.lang = lang
        self.translation = translation
        if self.lang == "en" and self.translation:
            logger.warning('There are no translations for `en` language.')
            self.translation = False

    def _convert_sample(self, sample: dict) -> dict:
        """Map one source manifest sample onto the output manifest schema."""
        new_sample = dict(
            source_lang=self.lang,
            target_lang=self.lang,
            yodas_id=sample['wav_id'],
            offset=sample['start_time'],
            duration=sample['duration'],
            text=sample['text'],
            answer=sample['text'],
            decodercontext="",
            emotion="<|emo:undefined|>",
            pnc="pnc",
            itn="itn",
            timestamp="notimestamp",
            diarize="nodiarize",
        )

        if self.translation:
            new_sample['target_lang'] = "en"
            new_sample['answer'] = sample['translation_en']

        return new_sample

    def process(self):
        """Download the manifests for ``self.lang`` and write the converted samples."""
        # Imported lazily, like the other Hugging Face processors do.
        from huggingface_hub import snapshot_download

        pattern = f"{self.lang}/{self.lang}*.json"
        if self.translation:
            pattern = f"Translation/{self.lang}_/{self.lang}*.jsonl"

        # dirname is empty for a bare filename, and os.makedirs("") raises.
        output_dir = os.path.dirname(self.output_manifest_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Per-call arguments are merged into a fresh dict so the temporary
            # directory does not leak into the stored configuration.
            snapshot_download(
                **{
                    **self.snapshot_download_args,
                    "allow_patterns": pattern,
                    "local_dir": tmp_dir,
                }
            )

            # The output is opened only after the download succeeded, so a failed
            # download does not leave a truncated manifest behind.
            with open(self.output_manifest_file, 'w', encoding='utf8') as fout:
                for manifest_filepath in sorted(glob(f"{tmp_dir}/{pattern}")):
                    with open(manifest_filepath, 'r', encoding='utf8') as fin:
                        for line in tqdm(fin, desc=f'Processing {os.path.basename(manifest_filepath)}'):
                            if not line.strip():
                                continue  # tolerate blank lines in the source manifest
                            new_sample = self._convert_sample(json.loads(line))
                            fout.write(json.dumps(new_sample) + '\n')

sdp/processors/huggingface/huggingface_hub.py

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,16 @@
1313
# limitations under the License.
1414

1515
import json
16+
import os
17+
from typing import Dict
1618

17-
from sdp.processors.base_processor import BaseProcessor
19+
from tqdm.contrib.concurrent import process_map
1820

21+
from sdp.processors.base_processor import BaseProcessor, BaseParallelProcessor
22+
23+
def _hf_hub_download(kwargs):
    """
    Download a single file from the Hugging Face Hub.

    Defined at module level so it can be handed to ``process_map`` as the worker
    function; ``kwargs`` is a dict of keyword arguments for
    ``huggingface_hub.hf_hub_download``, whose return value is passed through.
    """
    # Imported inside the function, as the other processors in this module do.
    import huggingface_hub

    local_path = huggingface_hub.hf_hub_download(**kwargs)
    return local_path
1926

2027
class ListRepoFiles(BaseProcessor):
2128
"""
@@ -93,34 +100,70 @@ class SnapshotDownload(BaseProcessor):
93100

94101
def __init__(
95102
self,
96-
output_manifest_file: str,
97-
input_manifest_file: str = None,
98-
**snapshot_download_kwargs,
103+
output_filepath_field: str = "downloaded",
104+
snapshot_download_args: dict = {},
105+
**kwargs,
99106
):
100-
super().__init__(
101-
output_manifest_file=output_manifest_file,
102-
input_manifest_file=input_manifest_file,
103-
)
104-
self.snapshot_download_kwargs = snapshot_download_kwargs
107+
super().__init__(**kwargs)
108+
self.output_filepath_field = output_filepath_field
109+
self.snapshot_download_args = snapshot_download_args
105110

106-
def download(self):
111+
def process(self):
107112
"""
108-
Download the repository snapshot to a local folder.
113+
Main processing entrypoint: download repo and write path to manifest.
109114
"""
110115
from huggingface_hub import snapshot_download
111116

112-
self.local_dir = snapshot_download(**self.snapshot_download_kwargs)
113-
114-
def write_output_manifest_file(self):
115-
"""
116-
Write the path of the downloaded snapshot folder to the output manifest.
117-
"""
117+
self.local_dir = snapshot_download(**self.snapshot_download_args)
118+
118119
with open(self.output_manifest_file, 'w', encoding='utf8') as fout:
119-
fout.writelines(json.dumps({"destination_dir": self.local_dir}))
120+
fout.writelines(json.dumps({self.output_filepath_field : self.local_dir}))
121+
122+
123+
class HfHubDownload(BaseParallelProcessor):
    """
    Download one file from a Hugging Face repository for every input manifest entry.

    For each entry, the value stored under ``filename_field`` is passed as
    ``filename`` to ``huggingface_hub.hf_hub_download`` together with
    ``hf_hub_download_args``. The value returned by the download is stored in the
    entry under ``output_filepath_field`` and the entry is written to the output
    manifest.

    Args:
        filename_field (str): Manifest field holding the name of the file to download.
        output_filepath_field (str): Manifest field that receives the download result.
            Defaults to ``"downloaded"``.
        hf_hub_download_args (Dict): Extra keyword arguments for
            ``hf_hub_download`` (e.g. ``repo_id``, ``repo_type``). Defaults to no
            extra arguments.
        **kwargs: Forwarded to :class:`BaseParallelProcessor`.
    """

    def __init__(
        self,
        filename_field: str,
        output_filepath_field: str = "downloaded",
        hf_hub_download_args: Dict = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.filename_field = filename_field
        self.output_filepath_field = output_filepath_field
        # Copy instead of storing the reference: avoids a shared mutable default
        # and keeps later changes to the caller's dict from affecting this processor.
        self.hf_hub_download_args = dict(hf_hub_download_args) if hf_hub_download_args else {}

    def process(self):
        """Download the files chunk by chunk and write the updated manifest."""
        self.prepare()

        # dirname is empty for a bare filename, and os.makedirs("") raises.
        output_dir = os.path.dirname(self.output_manifest_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
            for manifest_chunk in self._chunk_manifest():
                # Build the list of download tasks, one per manifest entry.
                download_tasks = [
                    {
                        **self.hf_hub_download_args,
                        "filename": entry[self.filename_field]
                    }
                    for entry in manifest_chunk
                ]

                # Download in parallel, honouring max_workers and chunksize.
                results = process_map(
                    _hf_hub_download,
                    download_tasks,
                    max_workers=self.max_workers,
                    chunksize=self.chunksize,
                )

                # Map the results back onto the input entries (order is preserved).
                for entry, local_path in zip(manifest_chunk, results):
                    entry[self.output_filepath_field] = local_path
                    json.dump(entry, fout, ensure_ascii=False)
                    fout.write("\n")
                    self.number_of_entries += 1

        self.finalize(self.test_cases)

    def process_dataset_entry(self, data_entry):
        """Unused: ``process`` is overridden and handles the entries itself."""
        pass

0 commit comments

Comments
 (0)