"""
NeMo ASR Model Profiler

This script performs a forward pass on a NeMo ASR model and measures its real-time factor (RTF).
RTF, the ratio of processing time to audio duration, is used to evaluate the processing speed
of ASR models; lower is better, and RTF < 1 means the model runs faster than real time.

# audio has to be a mono wav file with a 16kHz sample rate

Parameters:
    --model: ASR model name or path to the model checkpoint file.
    --decoding_type: Type of decoding to use (ctc or rnnt).
    --gpu: GPU device to use.
    --batch_size: Batch size to use for inference.
    --nbatches: Total number of batches to process.
    --warmup_batches: Number of batches to skip as warmup.
    --audio: Path to the input audio file for ASR.
    --audio_maxlen: Maximum duration of audio to process (in seconds).
    --precision: Model precision (16, 32, or bf16).
    --cudnn_benchmark: Enable cuDNN benchmarking.
    --log: Enable logging.

Example:
    python calculate_rtf.py --model stt_en_conformer_ctc_large --decoding_type ctc --gpu 0 --batch_size 1 --nbatches 5 --warmup_batches 5 --audio /path/to/audio.wav --audio_maxlen 600 --precision bf16 --cudnn_benchmark
"""
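
# A worked example of the RTF computed below (numbers are illustrative, not measured):
# if one batch covers 600 s of audio and takes 30 s of GPU time,
# RTF = 30 / 600 = 0.05, i.e. the model runs 20x faster than real time.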

import argparse
import copy
import sys
import time
from contextlib import nullcontext

import librosa  # needed by get_samples() to resample audio that is not already 16 kHz
import numpy as np
import soundfile as sf
import torch
from omegaconf import OmegaConf
from tqdm import tqdm

from nemo.collections.asr.models import ASRModel
from nemo.utils import logging


parser = argparse.ArgumentParser(description='model forward pass profiler / performance tester.')
parser.add_argument("--model", default='stt_en_conformer_ctc_large', type=str, help="ASR model name or path to a .nemo checkpoint")
parser.add_argument("--decoding_type", default='ctc', type=str, help="decoding type (ctc or rnnt)")
parser.add_argument("--gpu", default=0, type=int, help="GPU device to use")
parser.add_argument("--batch_size", default=1, type=int, help="batch size to use")
parser.add_argument("--nbatches", default=5, type=int, help="total number of batches to process")
parser.add_argument("--warmup_batches", default=5, type=int, help="number of batches to skip as warmup")
parser.add_argument("--audio", default="/disk3/datasets/speech-datasets/earnings22/media/4469669.wav", type=str, help="wav file to use")
parser.add_argument("--audio_maxlen", default=16, type=float, help="cut the file at the given length (seconds) if it is longer")
parser.add_argument("--precision", default='bf16', type=str, help="precision: 16/32/bf16")
# store_true flags need a False default, otherwise they are always on and the flag does nothing
parser.add_argument("--cudnn_benchmark", dest="enable_cudnn_bench", action="store_true", default=False, help="enable cudnn benchmarking")
parser.add_argument("--log", dest="log", action="store_true", default=False, help="enable logging")

args = parser.parse_args()

if args.log:
    logging.setLevel(20)  # INFO
else:
    logging.setLevel(40)  # ERROR; level 0 (NOTSET) would not actually silence logging

if args.enable_cudnn_bench:
    torch.backends.cudnn.benchmark = True

PRECISION = args.precision
WAV = args.audio
audio_maxlen = args.audio_maxlen
MODEL = args.model
batch_size = args.batch_size
nbatches = args.nbatches
warmup_batches = args.warmup_batches
decoding_type = args.decoding_type

DEVICE = torch.device(f'cuda:{args.gpu}')

if PRECISION not in ('16', '32', 'bf16'):
    logging.error(f'unknown precision: {PRECISION}')
    sys.exit(1)

logging.info(f'precision: {PRECISION}')
logging.info(f'WAV: {WAV}')
logging.info(f'AUDIO MAXLEN: {audio_maxlen}')
logging.info(f'MODEL: {MODEL}')
logging.info(f'batch_size: {batch_size}')
logging.info(f'num batches: {nbatches}')
logging.info(f'cudnn_benchmark: {args.enable_cudnn_bench}')


def get_samples(audio_file, audio_maxlen, target_sr=16000):
    """Load a wav file as float32 samples, truncated or zero-padded to audio_maxlen seconds."""
    with sf.SoundFile(audio_file, 'r') as f:
        sample_rate = f.samplerate
        samples = f.read(dtype='int16')
    # convert to float32 in [-1, 1) before resampling; librosa.resample requires float input
    samples = samples.astype('float32') / 32768
    if sample_rate != target_sr:
        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=target_sr)
    samples = samples.transpose()
    sample_length = samples.shape[0]
    target_length = int(audio_maxlen * target_sr)
    if sample_length > target_length:
        logging.info(f'resizing audio sample from {sample_length / target_sr} to maxlen of {audio_maxlen}')
        samples = samples[:target_length]
        logging.info(f'new sample length: {samples.shape[0]}')
    else:
        logging.info(f'padding audio sample from {sample_length / target_sr} to maxlen of {audio_maxlen}')
        samples = np.pad(samples, (0, target_length - sample_length), 'constant', constant_values=(0, 0))
    sample_length = target_length

    return samples, sample_length

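# Example use of get_samples (the path is hypothetical):
#   samples, n = get_samples('/path/to/audio.wav', audio_maxlen=600)
# returns a float32 array of exactly audio_maxlen * 16000 samples, truncated or zero-padded.
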
def preprocess_audio(preprocessor, audio, device):
    # add a batch dimension: (num_samples,) -> (1, num_samples)
    audio_signal = torch.from_numpy(audio).unsqueeze_(0).to(device)
    audio_signal_len = torch.Tensor([audio.shape[0]]).to(device)
    processed_signal, processed_signal_length = preprocessor(
        input_signal=audio_signal, length=audio_signal_len
    )
    return processed_signal, processed_signal_length

def extract_preprocessor(model, device):
    cfg = copy.deepcopy(model._cfg)
    OmegaConf.set_struct(cfg.preprocessor, False)
    # disable dither and padding so feature extraction is deterministic across runs
    cfg.preprocessor.dither = 0.0
    cfg.preprocessor.pad_to = 0
    preprocessor = model.from_config_dict(cfg.preprocessor)
    return preprocessor.to(device)

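# The preprocessor extracted above is the model's own acoustic frontend; for the Conformer
# checkpoints this is a log-mel spectrogram extractor (typically 80 mel bins), so
# preprocess_audio() yields a (batch, features, frames) tensor for the encoder.
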
def main():

    if MODEL.endswith('.nemo'):
        asr_model = ASRModel.restore_from(MODEL)
    else:
        asr_model = ASRModel.from_pretrained(MODEL)

    asr_model.to(DEVICE)
    asr_model.encoder.eval()
    asr_model.encoder.freeze()
    asr_model._prepare_for_export()

    processor = extract_preprocessor(asr_model, DEVICE)

    input_example, input_example_length = get_samples(WAV, audio_maxlen)
    logging.info(f'input example shape: {input_example.shape}')
    logging.info(f'input example length: {input_example_length}')
    processed_example, processed_example_length = preprocess_audio(processor, input_example, DEVICE)
    # replicate the single example to fill the batch
    processed_example = processed_example.repeat(batch_size, 1, 1)
    processed_example_length = processed_example_length.repeat(batch_size)
    logging.info(f'processed example shape: {processed_example.size()}')
    logging.info(f'processed example length shape: {processed_example_length.size()}')

    profiling_context = nullcontext()
    if PRECISION == '16':
        precision_context = torch.cuda.amp.autocast()
    elif PRECISION == 'bf16':
        precision_context = torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:  # '32', already validated above; full precision runs without autocast
        precision_context = nullcontext()

    if decoding_type == 'ctc':
        # reset to the default CTC decoding configuration
        asr_model.change_decoding_strategy(decoding_cfg=None)

    logging.info(f"running {nbatches} batches; with {warmup_batches} batches warmup; batch_size: {batch_size}")
    rtfs = []
    for _ in range(3):  # average over 3 runs; the inner loop uses i, so don't shadow it
        total_time = 0
        with profiling_context:
            with precision_context:
                with torch.no_grad():
                    for i in tqdm(range(nbatches + warmup_batches)):

                        start = time.time()
                        if decoding_type == 'rnnt':
                            enc_out, enc_len = asr_model.encoder.forward(
                                audio_signal=processed_example, length=processed_example_length
                            )
                            dec_out, dec_len = asr_model.decoding.rnnt_decoder_predictions_tensor(
                                encoder_output=enc_out, encoded_lengths=enc_len, return_hypotheses=False
                            )
                        else:
                            enc_out, enc_len, greedy_predictions = asr_model.forward(
                                processed_signal=processed_example, processed_signal_length=processed_example_length
                            )
                            dec_out, dec_len = asr_model.decoding.ctc_decoder_predictions_tensor(
                                enc_out, decoder_lengths=enc_len, return_hypotheses=False
                            )
                        # wait for all queued GPU work before stopping the clock
                        torch.cuda.synchronize()
                        end = time.time()
                        if i >= warmup_batches:
                            total_time += end - start

        # RTF = average compute time per batch / audio duration of one example
        rtf = (total_time / nbatches) / (input_example_length / 16000)
        rtfs.append(rtf)

    print(f'RTF over {len(rtfs)} runs: {rtfs}')
    rtf = sum(rtfs) / len(rtfs)
    sys.stdout.write(f'{rtf:.4f}\n')

if __name__ == '__main__':
    main()