Commit 68b9be9

Fix issue in NeMo eval for parallel runs overwriting the cache dir
Signed-off-by: smajumdar <[email protected]>
1 parent: 9caef99 · commit: 68b9be9

1 file changed: +13 −9 lines changed

nemo_asr/run_eval.py

Lines changed: 13 additions & 9 deletions
@@ -27,17 +27,19 @@ def dataset_iterator(dataset):
     }
 
 
-def write_audio(buffer) -> list:
-    if os.path.exists(DATA_CACHE_DIR):
-        shutil.rmtree(DATA_CACHE_DIR, ignore_errors=True)
+def write_audio(buffer, cache_prefix) -> list:
+    cache_dir = os.path.join(DATA_CACHE_DIR, cache_prefix)
+
+    if os.path.exists(cache_dir):
+        shutil.rmtree(cache_dir, ignore_errors=True)
 
-    os.makedirs(DATA_CACHE_DIR)
+    os.makedirs(cache_dir)
 
     data_paths = []
     for idx, data in enumerate(buffer):
         fn = os.path.basename(data['audio_filename'])
         fn = os.path.splitext(fn)[0]
-        path = os.path.join(DATA_CACHE_DIR, f"{idx}_{fn}.wav")
+        path = os.path.join(cache_dir, f"{idx}_{fn}.wav")
         data_paths.append(path)
 
         soundfile.write(path, data["array"], samplerate=data['sample_rate'])
@@ -52,14 +54,14 @@ def pack_results(results: list, buffer, transcriptions):
     return results
 
 
-def buffer_audio_and_transcribe(model: ASRModel, dataset, batch_size: int, verbose: bool = True):
+def buffer_audio_and_transcribe(model: ASRModel, dataset, batch_size: int, cache_prefix: str, verbose: bool = True):
     buffer = []
     results = []
     for sample in tqdm(dataset_iterator(dataset), desc='Evaluating: Sample id', unit='', disable=not verbose):
         buffer.append(sample)
 
         if len(buffer) == batch_size:
-            filepaths = write_audio(buffer)
+            filepaths = write_audio(buffer, cache_prefix)
             transcriptions = model.transcribe(filepaths, batch_size=batch_size, verbose=False)
             # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
             if type(transcriptions) == tuple and len(transcriptions) == 2:
@@ -68,7 +70,7 @@ def buffer_audio_and_transcribe(model: ASRModel, dataset, batch_size: int, verbo
             buffer.clear()
 
     if len(buffer) > 0:
-        filepaths = write_audio(buffer)
+        filepaths = write_audio(buffer, cache_prefix)
         transcriptions = model.transcribe(filepaths, batch_size=batch_size, verbose=False)
         # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
         if type(transcriptions) == tuple and len(transcriptions) == 2:
@@ -105,7 +107,9 @@ def main(args):
     references = []
 
     # run streamed inference
-    results = buffer_audio_and_transcribe(asr_model, dataset, args.batch_size, verbose=True)
+    cache_prefix = (f"{args.model_id.replace('/', '-')}-{args.dataset_path.replace('/', '')}-"
+                    f"{args.dataset.replace('/', '-')}-{args.split}")
+    results = buffer_audio_and_transcribe(asr_model, dataset, args.batch_size, cache_prefix, verbose=True)
     for sample in results:
         predictions.append(data_utils.normalizer(sample["pred_text"]))
         references.append(sample["reference"])

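For context on the fix: previously every run wrote its temporary WAV files into the shared DATA_CACHE_DIR and deleted that directory on startup, so evaluations running in parallel could wipe each other's audio cache. The per-run cache_prefix added here gives each run its own sub-directory. Below is a minimal sketch of the prefix construction and the resulting cache paths; the DATA_CACHE_DIR value and the model/dataset arguments are illustrative assumptions, not values taken from the repository.

import os

# Assumed cache root for illustration; the actual constant is defined in nemo_asr/run_eval.py.
DATA_CACHE_DIR = "audio_cache"

def make_cache_prefix(model_id: str, dataset_path: str, dataset: str, split: str) -> str:
    # Mirrors the prefix construction added in this commit: '/' characters are
    # stripped or replaced so the prefix is a single filesystem-safe path component.
    return (f"{model_id.replace('/', '-')}-{dataset_path.replace('/', '')}-"
            f"{dataset.replace('/', '-')}-{split}")

# Hypothetical arguments for two evaluation runs launched in parallel.
run_a = make_cache_prefix("nvidia/stt_en_conformer_ctc_large", "librispeech_asr", "clean", "test")
run_b = make_cache_prefix("nvidia/stt_en_conformer_ctc_large", "librispeech_asr", "other", "test")

# Each run now clears and writes only its own sub-directory instead of the shared root.
print(os.path.join(DATA_CACHE_DIR, run_a))  # audio_cache/nvidia-stt_en_conformer_ctc_large-librispeech_asr-clean-test
print(os.path.join(DATA_CACHE_DIR, run_b))  # audio_cache/nvidia-stt_en_conformer_ctc_large-librispeech_asr-other-test

Note that the prefix is deterministic, so two runs launched with identical model and dataset arguments would still map to the same sub-directory; the change separates runs that differ in model, dataset, config, or split.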