Skip to content

Commit 1089cb2

Browse files
committed
fixes for Canary-1B-Flash
Signed-off-by: Kunal Dhawan <[email protected]>
1 parent f34302c commit 1089cb2

File tree

4 files changed

+18
-102
lines changed

4 files changed

+18
-102
lines changed

nemo_asr/run_canary.sh

100644 → 100755
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
export PYTHONPATH="..":$PYTHONPATH
44

5-
MODEL_IDs=("nvidia/canary-1b")
5+
MODEL_IDs=("nvidia/canary-1b-flash") # options: "nvidia/canary-1b" "nvidia/canary-1b-flash"
66
BATCH_SIZE=64
77
DEVICE_ID=0
88

@@ -11,7 +11,6 @@ num_models=${#MODEL_IDs[@]}
1111
for (( i=0; i<${num_models}; i++ ));
1212
do
1313
MODEL_ID=${MODEL_IDs[$i]}
14-
1514

1615
python run_eval.py \
1716
--model_id=${MODEL_ID} \

nemo_asr/run_canary_flash.sh

Lines changed: 0 additions & 94 deletions
This file was deleted.

nemo_asr/run_eval.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22

3+
import io
34
import os
45
import torch
56
import evaluate
@@ -50,15 +51,26 @@ def download_audio_files(batch):
5051
audio_paths = []
5152
durations = []
5253

53-
# import ipdb; ipdb.set_trace()
54-
5554
for id, sample in zip(batch["id"], batch["audio"]):
5655
audio_path = os.path.join(CACHE_DIR, f"{id}.wav")
56+
57+
if "array" in sample:
58+
audio_array = np.float32(sample["array"])
59+
sample_rate = 16000
60+
61+
elif "bytes" in sample: # added to be compatible with latest datasets library (3.x.x) that produces byte stream
62+
with io.BytesIO(sample["bytes"]) as audio_file:
63+
audio_array, sample_rate = soundfile.read(audio_file, dtype="float32")
64+
65+
else:
66+
raise ValueError("Sample must have either 'array' or 'bytes' key")
67+
5768
if not os.path.exists(audio_path):
5869
os.makedirs(os.path.dirname(audio_path), exist_ok=True)
59-
soundfile.write(audio_path, np.float32(sample["array"]), 16_000)
70+
soundfile.write(audio_path, audio_array, sample_rate)
71+
6072
audio_paths.append(audio_path)
61-
durations.append(len(sample["array"]) / 16_000)
73+
durations.append(len(audio_array) / sample_rate)
6274

6375

6476
batch["references"] = batch["norm_text"]

requirements/requirements_nemo.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,4 @@ tqdm
33
soundfile
44
librosa
55
IPython # Workaround for https://github.com/NVIDIA/NeMo/pull/9890#discussion_r1701028427
6-
cuda-python>=12.4 # Used for fast TDT and RNN-T inference
7-
datasets <= 2.21.0
6+
cuda-python>=12.4 # Used for fast TDT and RNN-T inference

0 commit comments

Comments
 (0)