Commit afba94c

Merge pull request #55 from KunalDhawan/add_canary_flash
2 parents: 9a9c09f + 2de6ec6

File tree

3 files changed: +29 −8 lines

nemo_asr/run_canary.sh

File mode changed: 100644 → 100755
Lines changed: 2 additions & 3 deletions

@@ -2,16 +2,15 @@
 
 export PYTHONPATH="..":$PYTHONPATH
 
-MODEL_IDs=("nvidia/canary-1b")
-BATCH_SIZE=64
+MODEL_IDs=("nvidia/canary-1b-flash") # options: "nvidia/canary-1b" "nvidia/canary-1b-flash" "nvidia/canary-180m-flash"
+BATCH_SIZE=128
 DEVICE_ID=0
 
 num_models=${#MODEL_IDs[@]}
 
 for (( i=0; i<${num_models}; i++ ));
 do
     MODEL_ID=${MODEL_IDs[$i]}
-
 
     python run_eval.py \
         --model_id=${MODEL_ID} \

nemo_asr/run_eval.py

Lines changed: 25 additions & 3 deletions

@@ -1,5 +1,6 @@
 import argparse
 
+import io
 import os
 import torch
 import evaluate
@@ -51,12 +52,33 @@ def download_audio_files(batch):
     durations = []
 
     for id, sample in zip(batch["id"], batch["audio"]):
+
+        # first step added here to make ID and wav filenames unique
+        # several datasets like earnings22 have a hierarchical structure
+        # for eg. earnings22/test/4432298/281.wav, earnings22/test/4450488/281.wav
+        # lhotse uses the filename (281.wav) here as unique ID to create and name cuts
+        # ref: https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/collation.py#L186
+        id = id.replace('/', '_').removesuffix('.wav')
+
         audio_path = os.path.join(CACHE_DIR, f"{id}.wav")
+
+        if "array" in sample:
+            audio_array = np.float32(sample["array"])
+            sample_rate = 16000
+
+        elif "bytes" in sample:  # added to be compatible with latest datasets library (3.x.x) that produces byte stream
+            with io.BytesIO(sample["bytes"]) as audio_file:
+                audio_array, sample_rate = soundfile.read(audio_file, dtype="float32")
+
+        else:
+            raise ValueError("Sample must have either 'array' or 'bytes' key")
+
         if not os.path.exists(audio_path):
             os.makedirs(os.path.dirname(audio_path), exist_ok=True)
-            soundfile.write(audio_path, np.float32(sample["array"]), 16_000)
+            soundfile.write(audio_path, audio_array, sample_rate)
+
         audio_paths.append(audio_path)
-        durations.append(len(sample["array"]) / 16_000)
+        durations.append(len(audio_array) / sample_rate)
 
 
     batch["references"] = batch["norm_text"]
@@ -118,7 +140,7 @@ def download_audio_files(batch):
     # normalize transcriptions with English normalizer
     if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
         transcriptions = transcriptions[0]
-    predictions = [data_utils.normalizer(pred) for pred in transcriptions]
+    predictions = [data_utils.normalizer(pred.text) for pred in transcriptions]
 
     avg_time = total_time / len(all_data["audio_filepaths"])
 
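
The run_eval.py changes handle the two forms an audio sample can take when loaded with the Hugging Face datasets library: older versions expose an already decoded "array" (assumed 16 kHz here), while 3.x may hand back an encoded "bytes" stream that has to be decoded before the clip is cached to disk. A minimal standalone sketch of that decode path, using a hypothetical helper name decode_sample (not part of the repository) and assuming only numpy and soundfile:

import io

import numpy as np
import soundfile


def decode_sample(sample: dict):
    # Hypothetical helper mirroring the compatibility logic in the diff above.
    if "array" in sample:
        # datasets < 3.x: audio is already decoded; the eval script assumes 16 kHz
        return np.float32(sample["array"]), 16000
    if "bytes" in sample:
        # datasets 3.x: audio arrives as an encoded byte stream; decode it with soundfile
        with io.BytesIO(sample["bytes"]) as audio_file:
            audio_array, sample_rate = soundfile.read(audio_file, dtype="float32")
        return audio_array, sample_rate
    raise ValueError("Sample must have either 'array' or 'bytes' key")

Writing the cached wav with the sample rate returned by the decoder, rather than a hard-coded 16_000, keeps the file consistent with whatever the byte stream actually contained, and the duration is computed from the same pair. The last hunk is a related API adjustment: the returned transcriptions are hypothesis objects rather than plain strings, so the normalizer reads pred.text.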

requirements/requirements_nemo.txt

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
-git+https://github.com/NVIDIA/NeMo.git@d0efff087613ea2584e215969f289fed17414d8b#egg=nemo_toolkit[all] # This commit hash is a recent version of main at the time of testing.
+git+https://github.com/NVIDIA/NeMo.git@208e0da28e2ada8da84d8f7ddff8623efe1ff01c#egg=nemo_toolkit[asr] # This commit hash is a recent version of main at the time of testing.
 tqdm
 soundfile
 librosa
 IPython # Workaround for https://github.com/NVIDIA/NeMo/pull/9890#discussion_r1701028427
-cuda-python>=12.4 # Used for fast TDT and RNN-T inference
+cuda-python>=12.4 # Used for fast TDT and RNN-T inference
