 
 from tqdm import tqdm
 from normalizer import data_utils
+import numpy as np
 
 from nemo.collections.asr.models import ASRModel
+import time
 
-DATA_CACHE_DIR = os.path.join(os.getcwd(), "audio_cache")
 
 wer_metric = evaluate.load("wer")
 
 
-def dataset_iterator(dataset):
-    for i, item in enumerate(dataset):
-        yield {
-            **item["audio"],
-            "reference": item["norm_text"],
-            "audio_filename": f"file_{i}",
-            "sample_rate": 16_000,
-            "sample_id": i,
-        }
-
-
-def write_audio(buffer, cache_prefix) -> list:
-    cache_dir = os.path.join(DATA_CACHE_DIR, cache_prefix)
-
-    if os.path.exists(cache_dir):
-        shutil.rmtree(cache_dir, ignore_errors=True)
-
-    os.makedirs(cache_dir)
-
-    data_paths = []
-    for idx, data in enumerate(buffer):
-        fn = os.path.basename(data['audio_filename'])
-        fn = os.path.splitext(fn)[0]
-        path = os.path.join(cache_dir, f"{idx}_{fn}.wav")
-        data_paths.append(path)
-
-        soundfile.write(path, data["array"], samplerate=data['sample_rate'])
-
-    return data_paths
-
-
-def pack_results(results: list, buffer, transcriptions):
-    for sample, transcript in zip(buffer, transcriptions):
-        result = {'reference': sample['reference'], 'pred_text': transcript}
-        results.append(result)
-    return results
-
-
-def buffer_audio_and_transcribe(model: ASRModel, dataset, batch_size: int, pnc:bool, cache_prefix: str, verbose: bool = True):
-    buffer = []
-    results = []
-    for sample in tqdm(dataset_iterator(dataset), desc='Evaluating: Sample id', unit='', disable=not verbose):
-        buffer.append(sample)
-
-        if len(buffer) == batch_size:
-            filepaths = write_audio(buffer, cache_prefix)
-
-            if pnc is not None:
-                transcriptions = model.transcribe(filepaths, batch_size=batch_size, pnc=False, verbose=False)
-            else:
-                transcriptions = model.transcribe(filepaths, batch_size=batch_size, verbose=False)
-            # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
-            if type(transcriptions) == tuple and len(transcriptions) == 2:
-                transcriptions = transcriptions[0]
-            results = pack_results(results, buffer, transcriptions)
-            buffer.clear()
-
-    if len(buffer) > 0:
-        filepaths = write_audio(buffer, cache_prefix)
-        if pnc is not None:
-            transcriptions = model.transcribe(filepaths, batch_size=batch_size, pnc=False, verbose=False)
-        else:
-            transcriptions = model.transcribe(filepaths, batch_size=batch_size, verbose=False)
-        # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
-        if type(transcriptions) == tuple and len(transcriptions) == 2:
-            transcriptions = transcriptions[0]
-        results = pack_results(results, buffer, transcriptions)
-        buffer.clear()
-
-    # Delete temp cache dir
-    if os.path.exists(DATA_CACHE_DIR):
-        shutil.rmtree(DATA_CACHE_DIR)
-
-    return results
+def main(args):

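+    # cache decoded audio under ./audio_cache/<dataset>/<split>; existing WAV files are reused on later runs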
+    DATA_CACHE_DIR = os.path.join(os.getcwd(), "audio_cache")
+    DATASET_NAME = args.dataset
+    SPLIT_NAME = args.split

-def main(args):
+    CACHE_DIR = os.path.join(DATA_CACHE_DIR, DATASET_NAME, SPLIT_NAME)
+    if not os.path.exists(CACHE_DIR):
+        os.makedirs(CACHE_DIR)

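+    # use bfloat16 compute on GPU and full float32 on CPU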
     if args.device >= 0:
         device = torch.device(f"cuda:{args.device}")
+        compute_dtype = torch.bfloat16
     else:
         device = torch.device("cpu")
+        compute_dtype = torch.float32
+

     if args.model_id.endswith(".nemo"):
         asr_model = ASRModel.restore_from(args.model_id, map_location=device)
     else:
         asr_model = ASRModel.from_pretrained(args.model_id, map_location=device)  # type: ASRModel
-    asr_model.freeze()
+
+    asr_model.to(compute_dtype)
+    asr_model.eval()

     dataset = data_utils.load_data(args)

+    def download_audio_files(batch):
+
+        # decode the audio arrays to 16 kHz WAV files in the local cache and collect their paths, durations and reference texts
+        audio_paths = []
+        durations = []
+
+        for id, sample in zip(batch["id"], batch["audio"]):
+            audio_path = os.path.join(CACHE_DIR, f"{id}.wav")
+            os.makedirs(os.path.dirname(audio_path), exist_ok=True)
+            if not os.path.exists(audio_path):
+                soundfile.write(audio_path, np.float32(sample["array"]), 16_000)
+            audio_paths.append(audio_path)
+            durations.append(len(sample["array"]) / 16_000)
+
+        batch["references"] = batch["norm_text"]
+        batch["audio_filepaths"] = audio_paths
+        batch["durations"] = durations
+
+        return batch
+
+
     if args.max_eval_samples is not None and args.max_eval_samples > 0:
         print(f"Subsampling dataset to first {args.max_eval_samples} samples !")
         dataset = dataset.take(args.max_eval_samples)

     dataset = data_utils.prepare_data(dataset)
-
-    predictions = []
-    references = []
-
-    # run streamed inference
-    cache_prefix = (f"{args.model_id.replace('/', '-')}-{args.dataset_path.replace('/', '')}-"
-                    f"{args.dataset.replace('/', '-')}-{args.split}")
-    results = buffer_audio_and_transcribe(asr_model, dataset, args.batch_size, args.pnc, cache_prefix, verbose=True)
-    for sample in results:
-        predictions.append(data_utils.normalizer(sample["pred_text"]))
-        references.append(sample["reference"])
-
-    # Write manifest results
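+    # switch to batched greedy decoding for faster inference unless beam search is explicitly configured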
+    if asr_model.cfg.decoding.strategy != "beam":
+        asr_model.cfg.decoding.strategy = "greedy_batch"
+        asr_model.change_decoding_strategy(asr_model.cfg.decoding)
+
+    # preparing the offline dataset
+    dataset = dataset.map(download_audio_files, batch_size=args.batch_size, batched=True, remove_columns=["audio"])
+
+    # collect audio filepaths, durations and references from the prepared dataset
+
+    all_data = {
+        "audio_filepaths": [],
+        "durations": [],
+        "references": [],
+    }
+
+    data_itr = iter(dataset)
+    for data in tqdm(data_itr, desc="Downloading Samples"):
+        for key in all_data:
+            all_data[key].append(data[key])
+
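+    # sorting longest-first groups similar-length clips into each batch and minimizes padding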
+    # Sort audio_filepaths and references based on durations values
+    sorted_indices = sorted(range(len(all_data["durations"])), key=lambda k: all_data["durations"][k], reverse=True)
+    all_data["audio_filepaths"] = [all_data["audio_filepaths"][i] for i in sorted_indices]
+    all_data["references"] = [all_data["references"][i] for i in sorted_indices]
+    all_data["durations"] = [all_data["durations"][i] for i in sorted_indices]
+
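+    # run inference twice: the first pass warms up the model, only the second pass is timed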
+    total_time = 0
+    for _ in range(2):  # warmup once and calculate rtf
+        start_time = time.time()
+        with torch.cuda.amp.autocast(enabled=False, dtype=compute_dtype):
+            with torch.no_grad():
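+                # Canary models emit punctuation and capitalization; pnc='no' disables this so output matches the normalized references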
+                if 'canary' in args.model_id:
+                    transcriptions = asr_model.transcribe(all_data["audio_filepaths"], batch_size=args.batch_size, verbose=False, pnc='no', num_workers=1)
+                else:
+                    transcriptions = asr_model.transcribe(all_data["audio_filepaths"], batch_size=args.batch_size, verbose=False, num_workers=1)
+        end_time = time.time()
+        if _ == 1:
+            total_time += end_time - start_time
+
+    # normalize transcriptions with English normalizer
+    if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
+        transcriptions = transcriptions[0]
+    predictions = [data_utils.normalizer(pred) for pred in transcriptions]
+
+    avg_time = total_time / len(all_data["audio_filepaths"])
+
+    # Write manifest results (WER and RTFX)
     manifest_path = data_utils.write_manifest(
-        references, predictions, args.model_id, args.dataset_path, args.dataset, args.split
+        all_data["references"],
+        predictions,
+        args.model_id,
+        args.dataset_path,
+        args.dataset,
+        args.split,
+        audio_length=all_data["durations"],
+        transcription_time=[avg_time] * len(all_data["audio_filepaths"]),
     )
+
     print("Results saved at path:", os.path.abspath(manifest_path))

-    wer = wer_metric.compute(references=references, predictions=predictions)
+    wer = wer_metric.compute(references=all_data['references'], predictions=predictions)
     wer = round(100 * wer, 2)

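+    # RTFX = total audio duration / wall-clock transcription time (higher means faster than real time)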
+    audio_length = sum(all_data["durations"])
+    rtfx = audio_length / total_time
+    rtfx = round(rtfx, 2)
+
+    print("RTFX:", rtfx)
     print("WER:", wer, "%")


@@ -173,12 +184,6 @@ def main(args): |
         default=None,
         help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
     )
-    parser.add_argument(
-        "--pnc",
-        type=bool,
-        default=None,
-        help="flag to indicate inferene in pnc mode for models that support punctuation and capitalization",
-    )
     parser.add_argument(
         "--no-streaming",
         dest='streaming',