Commit b7ae64b

added pnc flag for canary nemo asr eval
Signed-off-by: KunalDhawan <[email protected]>
1 parent: da27d35

2 files changed: 27 additions, 4 deletions

nemo_asr/run_canary.sh

Lines changed: 10 additions & 0 deletions
@@ -3,6 +3,7 @@
 export PYTHONPATH="..":$PYTHONPATH
 
 MODEL_IDs=("nvidia/canary-1b")
+PNC=False
 BATCH_SIZE=64
 DEVICE_ID=0
 
@@ -19,6 +20,7 @@ do
     --dataset="ami" \
     --split="test" \
     --device=${DEVICE_ID} \
+    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
@@ -28,6 +30,7 @@ do
     --dataset="earnings22" \
     --split="test" \
     --device=${DEVICE_ID} \
+    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
@@ -37,6 +40,7 @@ do
     --dataset="gigaspeech" \
     --split="test" \
     --device=${DEVICE_ID} \
+    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
@@ -46,6 +50,7 @@ do
     --dataset="librispeech" \
     --split="test.clean" \
     --device=${DEVICE_ID} \
+    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
@@ -55,6 +60,7 @@ do
     --dataset="librispeech" \
     --split="test.other" \
     --device=${DEVICE_ID} \
+    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
@@ -64,6 +70,7 @@ do
     --dataset="spgispeech" \
     --split="test" \
     --device=${DEVICE_ID} \
+    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
@@ -73,6 +80,7 @@ do
     --dataset="tedlium" \
     --split="test" \
     --device=${DEVICE_ID} \
+    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
@@ -82,6 +90,7 @@ do
     --dataset="voxpopuli" \
     --split="test" \
    --device=${DEVICE_ID} \
+    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
@@ -91,6 +100,7 @@ do
     --dataset="common_voice" \
     --split="test" \
     --device=${DEVICE_ID} \
+    --pnc=${PNC} \
     --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
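Note on how the two changes interact: run_canary.sh passes the flag as --pnc=${PNC}, so run_eval.py receives the literal string "False". Because the new --pnc argument (declared below with type=bool) converts any non-empty string to True, the Python side presumably only checks whether the flag was supplied at all (pnc is not None) and hard-codes pnc=False in the transcribe call. A minimal, self-contained sketch of that argparse behaviour, for illustration only (not part of the commit):

import argparse

# Illustration: what run_eval.py's "--pnc" sees when run_canary.sh sets PNC=False.
parser = argparse.ArgumentParser()
parser.add_argument("--pnc", type=bool, default=None)

args = parser.parse_args(["--pnc=False"])   # i.e. --pnc=${PNC} with PNC=False
print(args.pnc)            # True  -- bool("False") is True for any non-empty string
print(args.pnc is None)    # False -- so the "pnc is not None" branch is taken
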

nemo_asr/run_eval.py

Lines changed: 17 additions & 4 deletions
@@ -54,15 +54,19 @@ def pack_results(results: list, buffer, transcriptions):
     return results
 
 
-def buffer_audio_and_transcribe(model: ASRModel, dataset, batch_size: int, cache_prefix: str, verbose: bool = True):
+def buffer_audio_and_transcribe(model: ASRModel, dataset, batch_size: int, pnc:bool, cache_prefix: str, verbose: bool = True):
     buffer = []
     results = []
     for sample in tqdm(dataset_iterator(dataset), desc='Evaluating: Sample id', unit='', disable=not verbose):
         buffer.append(sample)
 
         if len(buffer) == batch_size:
             filepaths = write_audio(buffer, cache_prefix)
-            transcriptions = model.transcribe(filepaths, batch_size=batch_size, verbose=False)
+
+            if pnc is not None:
+                transcriptions = model.transcribe(filepaths, batch_size=batch_size, pnc=False, verbose=False)
+            else:
+                transcriptions = model.transcribe(filepaths, batch_size=batch_size, verbose=False)
             # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
             if type(transcriptions) == tuple and len(transcriptions) == 2:
                 transcriptions = transcriptions[0]
@@ -71,7 +75,10 @@ def buffer_audio_and_transcribe(model: ASRModel, dataset, batch_size: int, cache
 
     if len(buffer) > 0:
         filepaths = write_audio(buffer, cache_prefix)
-        transcriptions = model.transcribe(filepaths, batch_size=batch_size, verbose=False)
+        if pnc is not None:
+            transcriptions = model.transcribe(filepaths, batch_size=batch_size, pnc=False, verbose=False)
+        else:
+            transcriptions = model.transcribe(filepaths, batch_size=batch_size, verbose=False)
         # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
         if type(transcriptions) == tuple and len(transcriptions) == 2:
             transcriptions = transcriptions[0]
@@ -112,7 +119,7 @@ def main(args):
     # run streamed inference
     cache_prefix = (f"{args.model_id.replace('/', '-')}-{args.dataset_path.replace('/', '')}-"
                     f"{args.dataset.replace('/', '-')}-{args.split}")
-    results = buffer_audio_and_transcribe(asr_model, dataset, args.batch_size, cache_prefix, verbose=True)
+    results = buffer_audio_and_transcribe(asr_model, dataset, args.batch_size, args.pnc, cache_prefix, verbose=True)
     for sample in results:
         predictions.append(data_utils.normalizer(sample["pred_text"]))
         references.append(sample["reference"])
@@ -166,6 +173,12 @@ def main(args):
         default=None,
         help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
     )
+    parser.add_argument(
+        "--pnc",
+        type=bool,
+        default=None,
+        help="flag to indicate inference in pnc mode for models that support punctuation and capitalization",
+    )
     parser.add_argument(
         "--no-streaming",
         dest='streaming',
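Because the branch above passes pnc=False to model.transcribe whenever --pnc is supplied, the flag's value itself is never consulted; that matches the PNC=False default in run_canary.sh. If the value were meant to drive punctuation and capitalization directly, one option is an explicit string-to-boolean converter plus forwarding the parsed value. The sketch below is an illustration only, and it assumes (as the diff implies) that the model's transcribe method accepts a pnc keyword argument:

import argparse

def str2bool(value: str) -> bool:
    # Map common textual booleans explicitly instead of relying on bool(str),
    # which treats every non-empty string (even "False") as True.
    if value.lower() in ("true", "t", "1", "yes"):
        return True
    if value.lower() in ("false", "f", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--pnc", type=str2bool, default=None,
                    help="Enable or disable punctuation and capitalization where supported.")
args = parser.parse_args(["--pnc=False"])   # -> args.pnc is False

# Only pass the keyword when the flag was given, and forward the parsed value
# (assumes transcribe accepts a `pnc` keyword, as the diff implies):
transcribe_kwargs = {} if args.pnc is None else {"pnc": args.pnc}
# transcriptions = model.transcribe(filepaths, batch_size=batch_size, verbose=False, **transcribe_kwargs)

Building the keyword dictionary once would also avoid repeating the if/else in both the batched and flush paths of buffer_audio_and_transcribe.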
