|
1 | 1 | import argparse |
2 | 2 |
|
| 3 | +import io |
3 | 4 | import os |
4 | 5 | import torch |
5 | 6 | import evaluate |
@@ -51,12 +52,33 @@ def download_audio_files(batch): |
51 | 52 | durations = [] |
52 | 53 |
|
53 | 54 | for id, sample in zip(batch["id"], batch["audio"]): |
| 55 | + |
| 56 | + # First, make the ID (and therefore the wav filename) unique. |
| 57 | + # Several datasets, such as earnings22, have a hierarchical structure, |
| 58 | + # e.g. earnings22/test/4432298/281.wav and earnings22/test/4450488/281.wav. |
| 59 | + # lhotse uses the filename (281.wav) as the unique ID to create and name cuts. |
| 60 | + # ref: https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/collation.py#L186 |
| 61 | + id = id.replace('/', '_').removesuffix('.wav') |
| 62 | + |
54 | 63 | audio_path = os.path.join(CACHE_DIR, f"{id}.wav") |
| 64 | + |
| 65 | + if "array" in sample: |
| 66 | + audio_array = np.float32(sample["array"]) |
| 67 | + sample_rate = 16000 |
| 68 | + |
| 69 | + elif "bytes" in sample: # added for compatibility with the latest datasets library (3.x.x), which produces a byte stream |
| 70 | + with io.BytesIO(sample["bytes"]) as audio_file: |
| 71 | + audio_array, sample_rate = soundfile.read(audio_file, dtype="float32") |
| 72 | + |
| 73 | + else: |
| 74 | + raise ValueError("Sample must have either 'array' or 'bytes' key") |
| 75 | + |
55 | 76 | if not os.path.exists(audio_path): |
56 | 77 | os.makedirs(os.path.dirname(audio_path), exist_ok=True) |
57 | | - soundfile.write(audio_path, np.float32(sample["array"]), 16_000) |
| 78 | + soundfile.write(audio_path, audio_array, sample_rate) |
| 79 | + |
58 | 80 | audio_paths.append(audio_path) |
59 | | - durations.append(len(sample["array"]) / 16_000) |
| 81 | + durations.append(len(audio_array) / sample_rate) |
60 | 82 |
|
61 | 83 |
|
62 | 84 | batch["references"] = batch["norm_text"] |
@@ -118,7 +140,7 @@ def download_audio_files(batch): |
118 | 140 | # normalize transcriptions with English normalizer |
119 | 141 | if isinstance(transcriptions, tuple) and len(transcriptions) == 2: |
120 | 142 | transcriptions = transcriptions[0] |
121 | | - predictions = [data_utils.normalizer(pred) for pred in transcriptions] |
| 143 | + predictions = [data_utils.normalizer(pred.text) for pred in transcriptions] |
122 | 144 |
|
123 | 145 | avg_time = total_time / len(all_data["audio_filepaths"]) |
124 | 146 |
|
|
0 commit comments