Commit 48ac13a

Author: Nithin Rao Koluguri
update nemo inference and include RTFx
Signed-off-by: Nithin Rao Koluguri <nithinraok>
1 parent 55c3d1d commit 48ac13a

File tree

5 files changed: +177 -125 lines changed

nemo_asr/run_canary.sh

Lines changed: 0 additions & 10 deletions
@@ -3,7 +3,6 @@
 export PYTHONPATH="..":$PYTHONPATH
 
 MODEL_IDs=("nvidia/canary-1b")
-PNC=False
 BATCH_SIZE=64
 DEVICE_ID=0

@@ -20,7 +19,6 @@ do
     --dataset="ami" \
     --split="test" \
     --device=${DEVICE_ID} \
-    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1

@@ -30,7 +28,6 @@ do
     --dataset="earnings22" \
     --split="test" \
     --device=${DEVICE_ID} \
-    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1

@@ -40,7 +37,6 @@ do
     --dataset="gigaspeech" \
     --split="test" \
     --device=${DEVICE_ID} \
-    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1

@@ -50,7 +46,6 @@ do
     --dataset="librispeech" \
     --split="test.clean" \
     --device=${DEVICE_ID} \
-    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1

@@ -60,7 +55,6 @@ do
     --dataset="librispeech" \
     --split="test.other" \
     --device=${DEVICE_ID} \
-    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1

@@ -70,7 +64,6 @@ do
     --dataset="spgispeech" \
     --split="test" \
     --device=${DEVICE_ID} \
-    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1

@@ -80,7 +73,6 @@ do
     --dataset="tedlium" \
     --split="test" \
     --device=${DEVICE_ID} \
-    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1

@@ -90,7 +82,6 @@ do
     --dataset="voxpopuli" \
     --split="test" \
     --device=${DEVICE_ID} \
-    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1

@@ -100,7 +91,6 @@ do
     --dataset="common_voice" \
     --split="test" \
     --device=${DEVICE_ID} \
-    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
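
The `PNC` variable and `--pnc` flag drop out of this script because run_eval.py (below) now pins `pnc='no'` for Canary models internally. A minimal sketch of the equivalent NeMo call, assuming a local `sample.wav` (hypothetical path):

    import torch
    from nemo.collections.asr.models import ASRModel

    # Load Canary from the Hugging Face hub, as run_eval.py does.
    model = ASRModel.from_pretrained("nvidia/canary-1b", map_location=torch.device("cuda:0"))
    model.eval()

    # pnc='no' requests lowercase, unpunctuated hypotheses, matching the
    # normalized references used for WER scoring.
    with torch.no_grad():
        hyps = model.transcribe(["sample.wav"], batch_size=64, pnc="no", verbose=False)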

nemo_asr/run_eval.py

Lines changed: 103 additions & 98 deletions
@@ -8,131 +8,142 @@
 
 from tqdm import tqdm
 from normalizer import data_utils
+import numpy as np
 
 from nemo.collections.asr.models import ASRModel
+import time
 
-DATA_CACHE_DIR = os.path.join(os.getcwd(), "audio_cache")
 
 wer_metric = evaluate.load("wer")
 
 
-def dataset_iterator(dataset):
-    for i, item in enumerate(dataset):
-        yield {
-            **item["audio"],
-            "reference": item["norm_text"],
-            "audio_filename": f"file_{i}",
-            "sample_rate": 16_000,
-            "sample_id": i,
-        }
-
-
-def write_audio(buffer, cache_prefix) -> list:
-    cache_dir = os.path.join(DATA_CACHE_DIR, cache_prefix)
-
-    if os.path.exists(cache_dir):
-        shutil.rmtree(cache_dir, ignore_errors=True)
-
-    os.makedirs(cache_dir)
-
-    data_paths = []
-    for idx, data in enumerate(buffer):
-        fn = os.path.basename(data['audio_filename'])
-        fn = os.path.splitext(fn)[0]
-        path = os.path.join(cache_dir, f"{idx}_{fn}.wav")
-        data_paths.append(path)
-
-        soundfile.write(path, data["array"], samplerate=data['sample_rate'])
-
-    return data_paths
-
-
-def pack_results(results: list, buffer, transcriptions):
-    for sample, transcript in zip(buffer, transcriptions):
-        result = {'reference': sample['reference'], 'pred_text': transcript}
-        results.append(result)
-    return results
-
-
-def buffer_audio_and_transcribe(model: ASRModel, dataset, batch_size: int, pnc: bool, cache_prefix: str, verbose: bool = True):
-    buffer = []
-    results = []
-    for sample in tqdm(dataset_iterator(dataset), desc='Evaluating: Sample id', unit='', disable=not verbose):
-        buffer.append(sample)
-
-        if len(buffer) == batch_size:
-            filepaths = write_audio(buffer, cache_prefix)
-
-            if pnc is not None:
-                transcriptions = model.transcribe(filepaths, batch_size=batch_size, pnc=False, verbose=False)
-            else:
-                transcriptions = model.transcribe(filepaths, batch_size=batch_size, verbose=False)
-            # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
-            if type(transcriptions) == tuple and len(transcriptions) == 2:
-                transcriptions = transcriptions[0]
-            results = pack_results(results, buffer, transcriptions)
-            buffer.clear()
-
-    if len(buffer) > 0:
-        filepaths = write_audio(buffer, cache_prefix)
-        if pnc is not None:
-            transcriptions = model.transcribe(filepaths, batch_size=batch_size, pnc=False, verbose=False)
-        else:
-            transcriptions = model.transcribe(filepaths, batch_size=batch_size, verbose=False)
-        # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
-        if type(transcriptions) == tuple and len(transcriptions) == 2:
-            transcriptions = transcriptions[0]
-        results = pack_results(results, buffer, transcriptions)
-        buffer.clear()
-
-    # Delete temp cache dir
-    if os.path.exists(DATA_CACHE_DIR):
-        shutil.rmtree(DATA_CACHE_DIR)
-
-    return results
+def main(args):
 
+    DATA_CACHE_DIR = os.path.join(os.getcwd(), "audio_cache")
+    DATASET_NAME = args.dataset
+    SPLIT_NAME = args.split
 
-def main(args):
+    CACHE_DIR = os.path.join(DATA_CACHE_DIR, DATASET_NAME, SPLIT_NAME)
+    if not os.path.exists(CACHE_DIR):
+        os.makedirs(CACHE_DIR)
 
     if args.device >= 0:
         device = torch.device(f"cuda:{args.device}")
+        compute_dtype = torch.bfloat16
     else:
         device = torch.device("cpu")
+        compute_dtype = torch.float32
 
     if args.model_id.endswith(".nemo"):
         asr_model = ASRModel.restore_from(args.model_id, map_location=device)
     else:
         asr_model = ASRModel.from_pretrained(args.model_id, map_location=device)  # type: ASRModel
-    asr_model.freeze()
+
+    asr_model.to(compute_dtype)
+    asr_model.eval()
 
     dataset = data_utils.load_data(args)
 
+    def download_audio_files(batch):
+        # save audio to the local cache and record paths, references and durations
+        audio_paths = []
+        durations = []
+
+        for id, sample in zip(batch["id"], batch["audio"]):
+            audio_path = os.path.join(CACHE_DIR, f"{id}.wav")
+            os.makedirs(os.path.dirname(audio_path), exist_ok=True)
+            if not os.path.exists(audio_path):
+                soundfile.write(audio_path, np.float32(sample["array"]), 16_000)
+            audio_paths.append(audio_path)
+            durations.append(len(sample["array"]) / 16_000)
+
+        batch["references"] = batch["norm_text"]
+        batch["audio_filepaths"] = audio_paths
+        batch["durations"] = durations
+
+        return batch
+
     if args.max_eval_samples is not None and args.max_eval_samples > 0:
         print(f"Subsampling dataset to first {args.max_eval_samples} samples !")
         dataset = dataset.take(args.max_eval_samples)
 
     dataset = data_utils.prepare_data(dataset)
-
-    predictions = []
-    references = []
-
-    # run streamed inference
-    cache_prefix = (f"{args.model_id.replace('/', '-')}-{args.dataset_path.replace('/', '')}-"
-                    f"{args.dataset.replace('/', '-')}-{args.split}")
-    results = buffer_audio_and_transcribe(asr_model, dataset, args.batch_size, args.pnc, cache_prefix, verbose=True)
-    for sample in results:
-        predictions.append(data_utils.normalizer(sample["pred_text"]))
-        references.append(sample["reference"])
-
-    # Write manifest results
+    if asr_model.cfg.decoding.strategy != "beam":
+        asr_model.cfg.decoding.strategy = "greedy_batch"
+        asr_model.change_decoding_strategy(asr_model.cfg.decoding)
+
+    # preparing the offline dataset
+    dataset = dataset.map(download_audio_files, batch_size=args.batch_size, batched=True, remove_columns=["audio"])
+
+    # gather audio_filepaths, durations and references from the mapped dataset
+    all_data = {
+        "audio_filepaths": [],
+        "durations": [],
+        "references": [],
+    }
+
+    data_itr = iter(dataset)
+    for data in tqdm(data_itr, desc="Downloading Samples"):
+        for key in all_data:
+            all_data[key].append(data[key])
+
+    # Sort audio_filepaths and references by duration (longest first)
+    sorted_indices = sorted(range(len(all_data["durations"])), key=lambda k: all_data["durations"][k], reverse=True)
+    all_data["audio_filepaths"] = [all_data["audio_filepaths"][i] for i in sorted_indices]
+    all_data["references"] = [all_data["references"][i] for i in sorted_indices]
+    all_data["durations"] = [all_data["durations"][i] for i in sorted_indices]
+
+    total_time = 0
+    for i in range(2):  # first pass warms up, second is timed for RTFx
+        start_time = time.time()
+        with torch.cuda.amp.autocast(enabled=False, dtype=compute_dtype):
+            with torch.no_grad():
+                if 'canary' in args.model_id:
+                    transcriptions = asr_model.transcribe(all_data["audio_filepaths"], batch_size=args.batch_size, verbose=False, pnc='no', num_workers=1)
+                else:
+                    transcriptions = asr_model.transcribe(all_data["audio_filepaths"], batch_size=args.batch_size, verbose=False, num_workers=1)
+        end_time = time.time()
+        if i == 1:
+            total_time += end_time - start_time
+
+    # normalize transcriptions with English normalizer
+    if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
+        transcriptions = transcriptions[0]
+    predictions = [data_utils.normalizer(pred) for pred in transcriptions]
+
+    avg_time = total_time / len(all_data["audio_filepaths"])
+
+    # Write manifest results (WER and RTFX)
     manifest_path = data_utils.write_manifest(
-        references, predictions, args.model_id, args.dataset_path, args.dataset, args.split
+        all_data["references"],
+        predictions,
+        args.model_id,
+        args.dataset_path,
+        args.dataset,
+        args.split,
+        audio_length=all_data["durations"],
+        transcription_time=[avg_time] * len(all_data["audio_filepaths"]),
     )
+
     print("Results saved at path:", os.path.abspath(manifest_path))
 
-    wer = wer_metric.compute(references=references, predictions=predictions)
+    wer = wer_metric.compute(references=all_data['references'], predictions=predictions)
     wer = round(100 * wer, 2)
 
+    audio_length = sum(all_data["durations"])
+    rtfx = audio_length / total_time
+    rtfx = round(rtfx, 2)
+
+    print("RTFX:", rtfx)
     print("WER:", wer, "%")
 
 
@@ -173,12 +184,6 @@ def main(args):
         default=None,
         help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
     )
-    parser.add_argument(
-        "--pnc",
-        type=bool,
-        default=None,
-        help="flag to indicate inferene in pnc mode for models that support punctuation and capitalization",
-    )
     parser.add_argument(
         "--no-streaming",
         dest='streaming',
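
The new timing loop is the heart of the RTFx addition: transcription runs twice, the first pass serving only to warm up CUDA kernels, and RTFx is total audio duration divided by the wall-clock time of the second pass. A standalone sketch of that pattern (the `transcribe_fn` callable is a hypothetical stand-in for `asr_model.transcribe`):

    import time

    def measure_rtfx(transcribe_fn, filepaths, durations):
        """Return (transcriptions, RTFx): seconds of audio transcribed per
        wall-clock second. transcribe_fn is any callable mapping a list of
        audio paths to a list of hypotheses."""
        total_time = 0.0
        transcriptions = None
        for i in range(2):
            start = time.time()
            transcriptions = transcribe_fn(filepaths)
            if i == 1:  # only the post-warmup pass is timed
                total_time = time.time() - start
        return transcriptions, round(sum(durations) / total_time, 2)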

nemo_asr/run_fast_conformer_ctc.sh

Lines changed: 3 additions & 3 deletions
@@ -3,8 +3,8 @@
 export PYTHONPATH="..":$PYTHONPATH
 
 #considering FC-XL, FC-XXL, FC-L, C-L, C-S CTC models
-MODEL_IDs=("nvidia/parakeet-ctc-1.1b" "nvidia/parakeet-ctc-0.6b" "nvidia/stt_en_fastconformer_ctc_xxlarge" "nvidia/stt_en_fastconformer_ctc_xlarge" "nvidia/stt_en_fastconformer_ctc_large" "nvidia/stt_en_conformer_ctc_large" "nvidia/stt_en_conformer_ctc_small")
-BATCH_SIZE=8
+MODEL_IDs=("nvidia/parakeet-ctc-1.1b" "nvidia/parakeet-ctc-0.6b" "nvidia/stt_en_fastconformer_ctc_large" "nvidia/stt_en_conformer_ctc_large" "nvidia/stt_en_conformer_ctc_small")
+BATCH_SIZE=64
 DEVICE_ID=0
 
 num_models=${#MODEL_IDs[@]}

@@ -48,7 +48,7 @@ do
     --split="test.clean" \
     --device=${DEVICE_ID} \
     --batch_size=${BATCH_SIZE} \
-    --max_eval_samples=-1
+    --max_eval_samples=-1
 
 python run_eval.py \
     --model_id=${MODEL_ID} \

nemo_asr/run_fast_conformer_rnnt.sh

Lines changed: 2 additions & 2 deletions
@@ -3,8 +3,8 @@
 export PYTHONPATH="..":$PYTHONPATH
 
 #considering FC-L, FC-XL, FC-XXL, C-L and C-S RNNT models
-MODEL_IDs=("nvidia/parakeet-rnnt-1.1b" "nvidia/parakeet-rnnt-0.6b" "nvidia/stt_en_fastconformer_transducer_large" "nvidia/stt_en_fastconformer_transducer_xlarge" "nvidia/stt_en_fastconformer_transducer_xxlarge" "nvidia/stt_en_conformer_transducer_large" "stt_en_conformer_transducer_small")
-BATCH_SIZE=8
+MODEL_IDs=("nvidia/parakeet-tdt-1.1b" "nvidia/parakeet-rnnt-1.1b" "nvidia/parakeet-rnnt-0.6b" "nvidia/stt_en_fastconformer_transducer_large" "nvidia/stt_en_conformer_transducer_large" "stt_en_conformer_transducer_small")
+BATCH_SIZE=64
 DEVICE_ID=0
 
 num_models=${#MODEL_IDs[@]}
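
Both conformer scripts move from BATCH_SIZE=8 to 64, which pairs naturally with the duration sort added in run_eval.py: transcribing longest-first keeps each batch filled with similar-length clips, so less computation is spent on padding. A toy illustration of that ordering step (values hypothetical):

    durations = [3.2, 12.5, 1.1, 7.8]  # seconds per clip (toy values)
    audio_filepaths = ["a.wav", "b.wav", "c.wav", "d.wav"]

    # Longest-first ordering, as run_eval.py now does before transcribing.
    order = sorted(range(len(durations)), key=lambda k: durations[k], reverse=True)
    audio_filepaths = [audio_filepaths[i] for i in order]
    durations = [durations[i] for i in order]
    print(audio_filepaths)  # ['b.wav', 'd.wav', 'a.wav', 'c.wav']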
