diff --git a/README.md b/README.md
old mode 100644
new mode 100755
index 4d06aaf..c6970cc
--- a/README.md
+++ b/README.md
@@ -140,6 +140,31 @@ print(results)
 - FireRedASR-AED supports audio input up to 60s. Input longer than 60s may cause hallucination issues, and input exceeding 200s will trigger positional encoding errors.
 - FireRedASR-LLM supports audio input up to 30s. The behavior for longer input is currently unknown.
 
+### FireRedASR-AED Optimization with ROCm
+1. Build the Docker image with `docker/Dockerfile.rocm` to set up the environment
+```
+docker build --network=host -f docker/Dockerfile.rocm -t rocm/firered-asr-opt .
+```
+
+2. Launch the Docker container
+```
+docker run -it --ipc=host --network=host --privileged --security-opt seccomp=unconfined --cap-add=CAP_SYS_ADMIN --cap-add=SYS_PTRACE --device=/dev/kfd --device=/dev/dri --device=/dev/mem rocm/firered-asr-opt
+```
+
+3. Run the performance test with the native MHA implementation (baseline)
+```
+ATTENTION_BACKEND="NATIVE" python examples/benchmark_firered_asr.py
+```
+
+4. Run the performance test with MHA using torch SDPA
+```
+ATTENTION_BACKEND="SDPA" python examples/benchmark_firered_asr.py
+```
+
+5. Run the performance test with MHA using xFormers
+```
+ATTENTION_BACKEND="XFORMERS" python examples/benchmark_firered_asr.py
+```
 
 ## Acknowledgements
 
 Thanks to the following open-source works:
diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
new file mode 100755
index 0000000..a099804
--- /dev/null
+++ b/docker/Dockerfile.rocm
@@ -0,0 +1,28 @@
+ARG BASE_IMAGE=rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0
+FROM ${BASE_IMAGE} AS base
+
+ARG PYTORCH_ROCM_ARCH=gfx942
+ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
+ENV HIP_ARCHITECTURE=${PYTORCH_ROCM_ARCH}
+ENV ROCM_PATH=/opt/rocm
+ENV XFORMERS_CK_FLASH_ATTN=1
+ARG PYTHON_VERSION=3.12
+ENV DEBIAN_FRONTEND=noninteractive
+
+WORKDIR /root
+
+RUN set -ex && usermod -a -G video $(whoami)
+
+RUN python3 -m pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython
+
+RUN git clone https://github.com/FireRedTeam/FireRedASR /root/FireRedASR \
+    && cd /root/FireRedASR \
+    && python3 -m pip install -r requirements.txt
+
+RUN git clone https://github.com/ROCm/xformers.git /root/xformers \
+    && cd /root/xformers \
+    && git checkout 5f0419a \
+    && git submodule update --init --recursive \
+    && PYTORCH_ROCM_ARCH=$PYTORCH_ROCM_ARCH HIP_ARCHITECTURE=$HIP_ARCHITECTURE XFORMERS_CK_FLASH_ATTN=$XFORMERS_CK_FLASH_ATTN python3 setup.py install
+
+WORKDIR /root/FireRedASR
\ No newline at end of file
diff --git a/examples/benchmark_firered_asr.py b/examples/benchmark_firered_asr.py
new file mode 100755
index 0000000..4acf51c
--- /dev/null
+++ b/examples/benchmark_firered_asr.py
@@ -0,0 +1,187 @@
+import os
+import time
+import torch
+import numpy as np
+from tqdm import tqdm
+import json
+import librosa
+import soundfile as sf
+import argparse
+
+from fireredasr.models.fireredasr import FireRedAsr
+
+from torch.profiler import profile as torch_profiler
+from torch.profiler import ProfilerActivity
+
+
+ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "SDPA")  # Options: "NATIVE", "SDPA", "XFORMERS"
+
+
+def load_model(model_path="pretrained_models/FireRedASR-AED-L"):
+    print("==========Load model:========")
+    model = FireRedAsr.from_pretrained("aed", model_path)
+    model.model.to(torch.float16)
+    model.model.cuda()
+    model.model.eval()
+
+    return model
+
+def load_audio(wav_path):
+    print("==========load audio:=========")
+    audio, sr = sf.read(wav_path, dtype=np.float32)
+    print(len(audio), audio.dtype)
+    if sr != 16000:
+        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
+    return audio
+
+def run(model, batch_wav_path, warmup=2, trials=10, enable_profile=False, offset=0):
+    batch_uttid = list(range(offset, offset + len(batch_wav_path), 1))
+    results = None
+    total_dur = None
+
+    preprocess_start = time.time()
+    feats, lengths, durs = model.feat_extractor(batch_wav_path)
+    feats = feats.to(torch.float16)
+    feats, lengths = feats.cuda(), lengths.cuda()
+    preprocess_dur = time.time() - preprocess_start
+    print(f"preprocess duration: {preprocess_dur:.3f} s")
+    total_dur = sum(durs)
+    avg_audio_dur_per_sample = total_dur / len(durs)
+    print(f"total input audio duration: {total_dur:.3f} s, avg input audio duration per sample: {avg_audio_dur_per_sample:.3f} s")
+    # Warmup
+    print("==========warmup========")
+    for _ in range(warmup):
+        with torch.no_grad():
+            _ = model.model.transcribe(feats, lengths, beam_size=3, nbest=1, decode_max_len=0, softmax_smoothing=1.25, length_penalty=0.6, eos_penalty=1.0)  # repetition_penalty=1.0, decode_min_len=0, temperature=1.0 used only for llm
+
+    # Benchmark
+    print("==========start benchmark========")
+    total_time = 0
+    results = []
+    rtf_list = []
+    if enable_profile:
+        # Profile a single run only
+        trials = 1
+    for _ in tqdm(range(trials)):
+        start = time.time()
+        with torch.no_grad():
+            if enable_profile:
+                with torch_profiler(
+                        activities=[ProfilerActivity.CPU,
+                                    ProfilerActivity.CUDA],
+                        record_shapes=True,
+                        with_stack=True,
+                        profile_memory=False) as prof:
+                    hyps = model.model.transcribe(feats, lengths, beam_size=3, nbest=1, decode_max_len=0, softmax_smoothing=1.25, length_penalty=0.6, eos_penalty=1.0)  # repetition_penalty=1.0, decode_min_len=0, temperature=1.0 used only for llm
+                print(prof.key_averages().table(sort_by="cuda_time_total"))
+                prof.export_chrome_trace(f"firered_asr_profile_{len(batch_wav_path)}_{ATTENTION_BACKEND}.json")
+            else:
+                hyps = model.model.transcribe(feats, lengths, beam_size=3, nbest=1, decode_max_len=0, softmax_smoothing=1.25, length_penalty=0.6, eos_penalty=1.0)  # repetition_penalty=1.0, decode_min_len=0, temperature=1.0 used only for llm
+        elapsed = time.time() - start
+        total_time += elapsed
+
+        rtf = elapsed / total_dur if total_dur > 0 else 0
+        for uttid, wav, hyp in zip(batch_uttid, batch_wav_path, hyps):
+            hyp = hyp[0]  # only return 1-best
+            hyp_ids = [int(id) for id in hyp["yseq"].cpu()]
+            text = model.tokenizer.detokenize(hyp_ids)
+            results.append({"uttid": uttid, "text": text, "wav": wav,
+                            "rtf": f"{rtf:.4f}"})
+        rtf_list.append(rtf)
+
+    batch_size = len(batch_wav_path)
+    avg_latency = total_time / trials
+    rps = batch_size / avg_latency
+    # Only print the last run's results, for debugging
+    print("Only printing the last run's results for debugging...")
+    for res in results[-batch_size:]:
+        print(res)
+    avg_rtf = sum(rtf_list) / len(rtf_list)
+    print(f"Finished benchmark test for batch size: {batch_size}, average latency: {avg_latency:.3f}s | RPS: {rps:.2f}, avg RTF: {avg_rtf:.3f}")
+
+    return rps, avg_latency, avg_rtf, avg_audio_dur_per_sample, results[-batch_size:]
+
+def benchmark(model, audio_dir, batch, warmup=2, trials=10, enable_profile=False):
+    # Get list of .wav files (case-insensitive)
+    batch_wav_path = []
+
+    # Collect file paths and durations
+    file_durations = []
+    input_wav_path_list = [os.path.join(audio_dir, f)
+                           for f in os.listdir(audio_dir)
+                           if f.lower().endswith('.wav')]
+
+    for file_path in input_wav_path_list:
+        try:
+            y, sr = librosa.load(file_path, sr=None)  # keep original sampling rate
+            duration = len(y) / sr
+            file_durations.append((file_path, duration))
+        except Exception as e:
+            print(f"Error processing {file_path}: {e}")
+
+    # Sort by duration (longest first)
+    file_durations.sort(key=lambda x: x[1], reverse=True)
+    # Optional: print all audio files with their durations
+    for i, (fp, dur) in enumerate(file_durations, start=1):
+        print(f"{i}. {fp} - {dur:.2f} sec")
+
+    dataset_size = len(input_wav_path_list)
+    # Loop through the data in batches
+    benchmark_results = []
+    e2e_start = time.time()
+    for start in range(0, dataset_size - dataset_size % batch, batch):
+        batch_wav_path = [path for path, _ in file_durations[start:start + batch]]
+        print(f"Processing a batch of {batch} files, from index {start} to {start + batch - 1}")
+        rps, avg_latency, avg_rtf, avg_audio_dur_per_sample, model_results = run(model, batch_wav_path, warmup, trials, enable_profile, offset=start)
+        benchmark_results.append((batch, avg_audio_dur_per_sample, avg_latency, rps, avg_rtf, model_results))
+
+    # Process remaining data if any
+    remainder = dataset_size % batch
+    if remainder:
+        start = dataset_size - remainder
+        last_batch_wav_path = [path for path, _ in file_durations[-remainder:]]
+        print(f"Processing {remainder} remaining files: {last_batch_wav_path}")
+        rps, avg_latency, avg_rtf, avg_audio_dur_per_sample, model_results = run(model, last_batch_wav_path, warmup, trials, enable_profile, offset=start)
+        benchmark_results.append((remainder, avg_audio_dur_per_sample, avg_latency, rps, avg_rtf, model_results))
+    e2e_duration = time.time() - e2e_start
+
+    return benchmark_results, e2e_duration
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(prog='Benchmark script for FireRedASR', usage='%(prog)s [options]')
+    parser.add_argument('-b', '--batch_sizes', type=int, nargs='+', default=[1], help='List of batch sizes for performance evaluation')
+    parser.add_argument('-m', '--model_path', type=str, default="pretrained_models/FireRedASR-AED-L", help='Path to model directory')
+    parser.add_argument('-a', '--audio_dir', type=str, default='examples/wav', help="Path to input audio directory")
+    parser.add_argument('-d', '--device', type=str, default='cuda', help="Target inference device")
+    parser.add_argument('-p', '--profile', action='store_true', help='Enable torch profiler')
+    args = parser.parse_args()
+    audio_dir = args.audio_dir
+    model_path = args.model_path
+    device = args.device
+    enable_profile = args.profile
+    batch_sizes = args.batch_sizes  # e.g., [1, 4, 8, 16, 32, 64, 128, 256]
+    model = load_model(model_path)
+
+    if enable_profile:
+        benchmark_results, e2e_duration = benchmark(model, audio_dir, batch=1, enable_profile=enable_profile)
+    else:
+        for batch in batch_sizes:
+            print(f"*************************** batch size {batch} ***************************")
+            benchmark_results, e2e_duration = benchmark(model, audio_dir, batch=batch, enable_profile=enable_profile)
+
+            print(f"\nbatch size: {batch}, e2e latency: {e2e_duration:.3f} s")
+            save_results = []
+            save_path = f"ATTENTION_BACKEND_{ATTENTION_BACKEND}_bs_{batch}_output.json"
+            for res in benchmark_results:
+                print(res[5])
+                save_results += res[5]
+            for res in benchmark_results:
+                print(f"batch size: {res[0]}, avg audio duration per sample: {res[1]:.3f} s, avg inference latency {res[2]:.3f} s | RPS: {res[3]:.2f}, avg RTF: {res[4]:.3f}")
+            with open(save_path, "w", encoding="utf-8") as final:
+                json.dump(save_results,
+                          final,
+                          indent=2,
+                          ensure_ascii=False,  # Keep non-ASCII characters intact
+                          default=lambda x: list(x) if isinstance(x, tuple) else str(x)
+                          )
print(f"Performance results written to {save_path}") diff --git a/fireredasr/models/fireredasr.py b/fireredasr/models/fireredasr.py index 9cb4e33..a9548a9 100644 --- a/fireredasr/models/fireredasr.py +++ b/fireredasr/models/fireredasr.py @@ -107,7 +107,7 @@ def transcribe(self, batch_uttid, batch_wav_path, args={}): def load_fireredasr_aed_model(model_path): - package = torch.load(model_path, map_location=lambda storage, loc: storage) + package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False) print("model args:", package["args"]) model = FireRedAsrAed.from_args(package["args"]) model.load_state_dict(package["model_state_dict"], strict=True) diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py old mode 100644 new mode 100755 index 2088b08..896021f --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -1,3 +1,4 @@ +import os from typing import List, Optional, Dict import torch @@ -5,6 +6,15 @@ import torch.nn.functional as F from torch import Tensor +try: + import xformers.ops as xops + xformers_available = True +except Exception as e: + xformers_available = False + print("xformers is not available because: %s", str(e)) + +ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "SDPA") # Option: "NATIVE", "SDPA", "XFORMERS" +print("ATTENTION_BACKEND: ", ATTENTION_BACKEND) class TransformerDecoder(nn.Module): def __init__( @@ -38,6 +48,12 @@ def __init__( def batch_beam_search(self, encoder_outputs, src_masks, beam_size=1, nbest=1, decode_max_len=0, softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0): + if ATTENTION_BACKEND.upper() == "XFORMERS": + for dec_layer in self.layer_stack: + dec_layer.cross_attn.attention.reset_attn_bias() + + for dec_layer in self.layer_stack: + dec_layer.cross_attn.clear_states() B = beam_size N, Ti, H = encoder_outputs.size() device = encoder_outputs.device @@ -137,6 +153,7 @@ def batch_beam_search(self, encoder_outputs, src_masks, } n_nbest_hyps.append(new_hyp) nbest_hyps.append(n_nbest_hyps) + return nbest_hyps def ignored_target_position_is_0(self, padded_targets, ignore_id): @@ -174,10 +191,10 @@ class DecoderLayer(nn.Module): def __init__(self, d_model, n_head, dropout): super().__init__() self.self_attn_norm = nn.LayerNorm(d_model) - self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout) + self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout, is_cross=False) self.cross_attn_norm = nn.LayerNorm(d_model) - self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout) + self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout, is_cross=True) self.mlp_norm = nn.LayerNorm(d_model) self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout) @@ -212,7 +229,7 @@ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask, class DecoderMultiHeadAttention(nn.Module): - def __init__(self, d_model, n_head, dropout=0.1): + def __init__(self, d_model, n_head, dropout=0.1, is_cross = False): super().__init__() self.d_model = d_model self.n_head = n_head @@ -222,17 +239,44 @@ def __init__(self, d_model, n_head, dropout=0.1): self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) self.w_vs = nn.Linear(d_model, n_head * self.d_k) - self.attention = DecoderScaledDotProductAttention( - temperature=self.d_k ** 0.5) + # Native multi-head attention + if ATTENTION_BACKEND.upper() == "NATIVE": + self.attention = 
+        # Torch SDPA
+        elif ATTENTION_BACKEND.upper() == "SDPA":
+            self.attention = DecoderTorchSDPA(temperature=self.d_k ** 0.5)
+        # xFormers attention
+        elif ATTENTION_BACKEND.upper() == "XFORMERS":
+            if not xformers_available:
+                raise ImportError("ATTENTION_BACKEND='XFORMERS' selected, but the xformers package is not available. Please install xformers.")
+            self.attention = DecoderXFormersAttention(self.n_head, self.d_k, self.d_model, temperature=self.d_k ** 0.5, is_cross=is_cross)
+        else:
+            raise ValueError(f"Unsupported attention backend: {ATTENTION_BACKEND}")
 
         self.fc = nn.Linear(n_head * self.d_k, d_model)
         self.dropout = nn.Dropout(dropout)
+        self.is_cross = is_cross
+        self.kv_proj = None
 
-    def forward(self, q, k, v, mask=None):
+    def clear_states(self):
+        self.kv_proj = None
+
+    def forward(self, q, k, v, mask=None, cross_kv_cache=None):
         bs = q.size(0)
 
         q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k)
-        k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
-        v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
+        if self.is_cross:
+            # Cross-attention reuses the same K/V projections throughout the decoding phase
+            if self.kv_proj is None:
+                self.kv_proj = (
+                    self.w_ks(k).view(bs, -1, self.n_head, self.d_k),
+                    self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
+                )
+            k, v = self.kv_proj
+        else:
+            k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
+            v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -249,6 +293,7 @@ def forward(self, q, k, v, mask=None):
         return output
 
 
+# Native scaled dot-product attention
 class DecoderScaledDotProductAttention(nn.Module):
     def __init__(self, temperature):
         super().__init__()
@@ -264,8 +309,116 @@ def forward(self, q, k, v, mask=None):
         else:
             attn = torch.softmax(attn, dim=-1)
         output = torch.matmul(attn, v)
+
         return output
+
+
+# Torch SDPA
+class DecoderTorchSDPA(nn.Module):
+    def __init__(self, temperature):
+        super().__init__()
+        self.temperature = temperature
+        self.scale = 1 / self.temperature
+
+    def forward(self, q, k, v, mask=None):
+        bs = q.size(0)
+        if bs == 1 or mask is None:
+            output = F.scaled_dot_product_attention(
+                q, k, v,
+                scale=self.scale
+            )
+        else:
+            if mask.dtype != torch.bool:
+                mask = mask.eq(1)
+            output = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=mask,
+                scale=self.scale
+            )
+        return output
+
+
+class XFormersAttentionMetadata:
+    """Metadata for the xFormers attention backend."""
+    def __init__(self, is_cross):
+        self.is_cross = is_cross
+        self.attn_bias = None
+
+    def set_cross_attn_bias(self, mask, bs, q_len, k_len, n_head, dtype, device):
+        if self.is_cross:
+            mask = mask.to(torch.bool)
+
+            # If the mask has size 1 along the q_len dimension, expand it
+            if mask.size(2) == 1 and q_len > 1:
+                mask = mask.expand(bs, 1, q_len, k_len)
+
+            # Expand the mask to all heads
+            mask = mask.expand(bs, n_head, q_len, k_len) \
+                       .reshape(bs * n_head, q_len, k_len)
+
+            # Alignment requirement for xformers: pad the allocation to a multiple of 8
+            pad_k = ((k_len + 7) // 8) * 8
+            pad_q = ((q_len + 7) // 8) * 8
+
+            bias_full = torch.zeros(bs * n_head, pad_q, pad_k,
+                                    dtype=dtype, device=device)
+
+            bias_full[:, :q_len, :k_len].masked_fill_(~mask, float("-inf"))
+
+            # Slice down to the actual shape but keep the aligned backing storage
+            self.attn_bias = bias_full[:, :q_len, :k_len]
+
+    def get_attn_bias(self):
+        return self.attn_bias
+
+    def reset_attn_bias(self):
+        self.attn_bias = None
+
+
+# xFormers attention
+class DecoderXFormersAttention(nn.Module):
+    def __init__(self, n_head, d_k, d_model, temperature, is_cross):
+        super().__init__()
+        self.temperature = temperature
+        self.n_head = n_head
+        self.d_k = d_k
+        self.d_model = d_model
+        self.attention_metadata = XFormersAttentionMetadata(is_cross)
+        self.is_cross = is_cross
+
+    def reset_attn_bias(self):
+        self.attention_metadata.reset_attn_bias()
+
+    def forward(self, q, k, v, mask=None):
+        bs = q.size(0)
+        # Save sequence lengths
+        q_len = q.size(2)  # seq_len_q
+        k_len = k.size(2)  # seq_len_k
+        dtype = q.dtype
+
+        # Flatten heads into the batch dimension and cast to fp16 for the memory-efficient kernel
+        q = q.reshape(bs * self.n_head, -1, self.d_k).to(torch.float16)
+        k = k.reshape(bs * self.n_head, -1, self.d_k).to(torch.float16)
+        v = v.reshape(bs * self.n_head, -1, self.d_k).to(torch.float16)
+
+        if bs == 1:
+            output = xops.memory_efficient_attention(q, k, v)
+        else:
+            attn_bias = None
+            # --- Cross-attention / padding mask ---
+            if self.is_cross and mask is not None:
+                if self.attention_metadata.get_attn_bias() is None:
+                    self.attention_metadata.set_cross_attn_bias(mask, bs, q_len, k_len, self.n_head, q.dtype, q.device)
+                attn_bias = self.attention_metadata.get_attn_bias()
+
+            # --- Run memory-efficient attention ---
+            output = xops.memory_efficient_attention(q, k, v,
+                                                     attn_bias=attn_bias)
+
+        # Reshape back to (bs, n_head, seq_len_q, d_k) so the caller can merge heads into
+        # (bs, seq_len, d_model) as with the other backends, then restore the original dtype
+        output = output.reshape(bs, self.n_head, q_len, self.d_k)
+        return output.to(dtype)
 
 
 class PositionwiseFeedForward(nn.Module):
     def __init__(self, d_model, d_ff, dropout=0.1):
diff --git a/requirements.txt b/requirements.txt
old mode 100644
new mode 100755
index 40afd7e..b5ded77
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ peft>=0.13.2
 sentencepiece
 torch>=2.0.0
 transformers>=4.46.3
+librosa
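A quick way to sanity-check that the SDPA backend computes the same attention as the hand-written baseline is a standalone comparison in plain PyTorch. The snippet below is only an illustrative sketch: the tensor shapes are made up, it does not import any FireRedASR code, and it assumes a PyTorch build recent enough to accept the `scale` argument of `F.scaled_dot_product_attention` (which the new `DecoderTorchSDPA` also relies on).

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: batch*beam = 6, heads = 8, query len = 4, key len = 120, head dim = 64
torch.manual_seed(0)
q = torch.randn(6, 8, 4, 64)
k = torch.randn(6, 8, 120, 64)
v = torch.randn(6, 8, 120, 64)

# Boolean padding mask over key positions (True = attend); mask out the last 20 frames
mask = torch.ones(6, 1, 1, 120, dtype=torch.bool)
mask[..., 100:] = False

scale = 1.0 / (64 ** 0.5)

# Reference: explicit scaled dot-product attention, the same math as DecoderScaledDotProductAttention
attn = torch.matmul(q, k.transpose(2, 3)) * scale
attn = attn.masked_fill(~mask, float("-inf")).softmax(dim=-1)
ref = torch.matmul(attn, v)

# Fused kernel path, i.e. the call DecoderTorchSDPA makes when ATTENTION_BACKEND="SDPA"
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=scale)

print(torch.allclose(ref, out, atol=1e-5))  # expected: True
```

The same comparison can be repeated against `xformers.ops.memory_efficient_attention` on a ROCm machine by materializing the boolean mask as an additive float bias, which is what `XFormersAttentionMetadata.set_cross_attn_bias` does.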