Skip to content

Commit 2796f6d

Browse files
committed
Moved Granary to cfg_end_to_end tests
Signed-off-by: Sasha Meister <ameister@nvidia.com>
1 parent bcc3bf0 commit 2796f6d

File tree

5 files changed

+74
-136
lines changed

5 files changed

+74
-136
lines changed

dataset_configs/multilingual/granary/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ and is published as [nvidia/Granary](https://huggingface.co/datasets/nvidia/Gran
2222
2323
> Note — GPU required
2424
>
25-
> All Whisper, vLLM, FastText and Comet-QE stages expect at least one CUDA-capable GPU. Multi-GPU nodes are auto-detected when `num_devices: -1` (default) is used.
25+
> All Whisper, vLLM and Comet-QE stages expect at least one CUDA-capable GPU. Multi-GPU nodes are auto-detected when `num_devices: -1` (default) is used.
2626
2727
### Software prerequisites
2828

dataset_configs/multilingual/granary/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ documentation: |
3434
3535
.. note::
3636
37-
**GPU required.** All Whisper, vLLM, FastText and Comet-QE stages expect at
37+
**GPU required.** All Whisper, vLLM and Comet-QE stages expect at
3838
least one CUDA-capable GPU. Multi-GPU nodes are auto-detected when
3939
``num_devices: -1`` (default) is used.
4040

tests/test_cfg_end_to_end_tests.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import tarfile
1818
import logging
1919
from functools import partial
20+
from itertools import chain
2021
from pathlib import Path
2122
from typing import Callable, List, Tuple
2223
from unittest import mock
@@ -135,6 +136,46 @@ def data_check_fn_armenian_toloka_pipeline_get_final_res(raw_data_dir: str) -> N
135136
]
136137
)
137138

139+
def data_check_fn_granary(raw_data_dir: str) -> None:
    """Validate the Granary test data directory and prepare its manifests.

    The function first checks that every expected audio file, intermediate
    manifest and cache file exists under ``raw_data_dir``. It then writes
    ``input_manifest.json`` (one ``source_audio_filepath`` entry per audio
    file) and rewrites each listed manifest in place so that its
    ``source_audio_filepath`` / ``audio_filepath`` values are prefixed with
    ``raw_data_dir``.

    Args:
        raw_data_dir: Directory holding the extracted Granary test data.

    Raises:
        ValueError: If any of the expected files is missing.
    """

    def create_init_manifest(audio_files, raw_data_dir):
        # One JSON object per line, pointing at each source audio file.
        with open(os.path.join(raw_data_dir, "input_manifest.json"), "wt", encoding="utf8") as f:
            for audio_file in audio_files:
                line = json.dumps({"source_audio_filepath": os.path.join(raw_data_dir, audio_file)})
                f.write(line + "\n")

    def to_abs_audio_paths(manifest_filepath, raw_data_dir):
        # Parse everything BEFORE reopening the file for writing: opening in
        # "wt" mode truncates it, so a malformed line must be detected while
        # the original content is still intact. Blank lines are skipped.
        with open(manifest_filepath, "rt", encoding="utf8") as f:
            samples = [json.loads(line) for line in f if line.strip()]

        for sample in samples:
            for audio_field in ("source_audio_filepath", "audio_filepath"):
                if audio_field in sample:
                    # Joining with an already-absolute path leaves it
                    # unchanged, so re-running on prepared data is harmless.
                    sample[audio_field] = str(Path(raw_data_dir) / sample[audio_field])

        with open(manifest_filepath, "wt", encoding="utf8") as f:
            for sample in samples:
                f.write(json.dumps(sample) + "\n")

    audio_files = [
        "audio/zCW0Pa0BI4Q.wav",
        "audio/zHWk3Ae7qJ0.wav",
        "audio/zHtFdl5K8qg.wav",
        "audio/zCW9rGbaF4E.wav",
        "audio/zG3RpHaMzkQ.wav",
    ]

    manifest_files = [
        "reference_manifest.json",
        "manifest_03.json",
        "manifest_06.json",
        "manifest_14.json",
        "manifest_21.json",
        "manifest_26.json",
        "manifest_41.json",
    ]

    cache_files = ["cache/histograms/en", "cache/histograms/it", "cache/models/lid.176.bin"]

    # Fail before touching anything if the data directory is incomplete.
    for file in chain(audio_files, manifest_files, cache_files):
        if not (Path(raw_data_dir) / file).exists():
            raise ValueError(f"No such file {str(Path(raw_data_dir) / file)}")

    create_init_manifest(audio_files, raw_data_dir)
    for manifest_file in manifest_files:
        to_abs_audio_paths(Path(raw_data_dir) / manifest_file, raw_data_dir)
178+
138179
def get_test_cases() -> List[Tuple[str, Callable]]:
139180
return [
140181
TestCase(
@@ -281,6 +322,12 @@ def get_test_cases() -> List[Tuple[str, Callable]]:
281322
data_check_fn=partial(data_check_fn_generic, file_name="manifest_22khz.json"),
282323
reference_manifest_filename="test_data_reference_bandwidth.json",
283324
),
325+
TestCase(
326+
config_path=f"{DATASET_CONFIGS_ROOT}/multilingual/granary/config.yaml",
327+
data_check_fn=data_check_fn_granary,
328+
reference_manifest_filename="reference_manifest.json",
329+
fields_to_ignore=['audio_filepath'],
330+
),
284331
]
285332

286333
def get_test_names():
@@ -350,7 +397,6 @@ def setup_data(request):
350397
if os.getenv("CLEAN_UP_DATA_DIR", "0") != "0":
351398
shutil.rmtree(data_dir)
352399

353-
354400
def test_data_availability(setup_data):
355401
_, data_check_fn, reference_manifest_filename, data_dir, fields_to_ignore, _ = setup_data
356402
try:
@@ -380,7 +426,6 @@ def test_configs(setup_data, tmp_path):
380426
cfg.data_split = cfg.get("data_split", "train")
381427
cfg.processors[0].raw_data_dir = data_dir.as_posix()
382428

383-
384429
if "already_downloaded" in cfg["processors"][0]:
385430
cfg["processors"][0]["already_downloaded"] = True
386431

@@ -404,6 +449,30 @@ def test_configs(setup_data, tmp_path):
404449
if "english/hifitts2/config_bandwidth" in config_path:
405450
cfg.processors[0].audio_dir = (data_dir / "audio_22khz").as_posix()
406451
cfg.processors[0].input_manifest_file = (data_dir / "manifest_22khz.json").as_posix()
452+
453+
if "multilingual/granary/config" in config_path:
454+
cfg.input_manifest_file = data_dir / "input_manifest.json"
455+
cfg.output_dir = data_dir
456+
cfg.sdp_dir = Path(__file__).parents[1]
457+
cfg.final_manifest = cfg.processors[-1].output_manifest_file
458+
459+
# Disable processors that use a GPU
460+
processors_to_disable = [
461+
3, 6, 14, # FasterWhisperInference
462+
21, 26, # vLLMInference
463+
41, # CometoidWMTQualityEstimation
464+
]
465+
466+
for processor_idx in processors_to_disable:
467+
processor_id = str(processor_idx).zfill(2)
468+
cfg.processors[processor_idx].should_run = False
469+
cfg.processors[processor_idx + 1].input_manifest_file = os.path.join(data_dir, f"manifest_{processor_id}.json")
470+
471+
# Set cache directories
472+
cfg.processors[33].cache_dir = os.path.join(data_dir, "cache", "histograms")
473+
cfg.processors[34].cache_dir = os.path.join(data_dir, "cache", "histograms")
474+
cfg.processors[37].cache_dir = os.path.join(data_dir, "cache", "models")
475+
cfg.processors[38].cache_dir = os.path.join(data_dir, "cache", "models")
407476

408477
run_processors(cfg)
409478
# additionally, let's test that final generated manifest matches the

tests/test_data_to_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def en_hist_dir(tmp_path_factory):
319319
s3.download_file(bucket, key, str(local_path))
320320
except ClientError as e:
321321
code = e.response.get("Error", {}).get("Code", "")
322-
pytest.skip(f"Cannot download s3://{bucket}/{key} ({code}).")
322+
raise FileNotFoundError(f"Cannot download s3://{bucket}/{key} ({code}).")
323323

324324
assert local_path.exists(), "Histogram file was not downloaded"
325325
return str(tmp_dir)

tests/test_granary_pipeline_end_to_end.py

Lines changed: 0 additions & 131 deletions
This file was deleted.

0 commit comments

Comments
 (0)