From 974386f8120d3f670e57f68930d77e2e70bb2a3f Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Wed, 29 Oct 2025 15:54:31 +0800 Subject: [PATCH 01/13] Add FireRedASR optimization on ROCm Signed-off-by: Xiake Sun --- README.md | 25 ++++ docker/Dockerfile.rocm | 28 +++++ examples/benchmark_firered_asr.py | 115 ++++++++++++++++++ fireredasr/models/fireredasr.py | 2 +- .../models/module/transformer_decoder.py | 70 ++++++++++- requirements.txt | 1 + 6 files changed, 238 insertions(+), 3 deletions(-) mode change 100644 => 100755 README.md create mode 100755 docker/Dockerfile.rocm create mode 100755 examples/benchmark_firered_asr.py mode change 100644 => 100755 fireredasr/models/module/transformer_decoder.py mode change 100644 => 100755 requirements.txt diff --git a/README.md b/README.md old mode 100644 new mode 100755 index 4d06aaf..2c81468 --- a/README.md +++ b/README.md @@ -140,6 +140,31 @@ print(results) - FireRedASR-AED supports audio input up to 60s. Input longer than 60s may cause hallucination issues, and input exceeding 200s will trigger positional encoding errors. - FireRedASR-LLM supports audio input up to 30s. The behavior for longer input is currently unknown. +## FireRedASR-AED Optimization with ROCm +1. Build docker image with `docker/Dockerfile.rocm` to setup environemt +``` +docker build --network=host -f docker/Dockerfile.rocm -t rocm/firered-asr-opt +``` + +2. Launch docker container +``` +docker run -it --ipc=host --network=host --privileged --security-opt seccomp=unconfined --cap-add=CAP_SYS_ADMIN --cap-add=SYS_PTRACE --device=/dev/kfd --device=/dev/dri --device=/dev/mem rocm/firered-asr-opt +``` + +3. Run performance test with native MHA (baseline) +```python +ATTENTION_BACKEND="NATIVE" python example/benchmark_firered_asr.py +``` + +4. Run performance test with MHA using torch SDPA +```python +ATTENTION_BACKEND="SDPA" python example/benchmark_firered_asr.py +``` + +5. 
Run performance test with MHA using xFormers +```python +ATTENTION_BACKEND="XFORMERS" python example/benchmark_firered_asr.py +``` ## Acknowledgements Thanks to the following open-source works: diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm new file mode 100755 index 0000000..a099804 --- /dev/null +++ b/docker/Dockerfile.rocm @@ -0,0 +1,28 @@ +ARG BASE_IMAGE=rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0 +FROM ${BASE_IMAGE} AS base + +ARG PYTORCH_ROCM_ARCH=gfx942 +ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} +ENV HIP_ARCHITECTURE=${PYTORCH_ROCM_ARCH} +ENV ROCM_PATH=/opt/rocm +ENV XFORMERS_CK_FLASH_ATTN=1 +ARG PYTHON_VERSION=3.12 +ENV DEBIAN_FRONTEND=noninteractive + +WORKDIR /root + +RUN set -ex && usermod -a -G video $(whoami) + +RUN python3 -m pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython + +RUN git clone https://github.com/FireRedTeam/FireRedASR /root/FireRedASR \ + && cd /root/FireRedASR \ + && python3 -m pip install -r requirements.txt + +RUN git clone https://github.com/ROCm/xformers.git /root/xformers \ + && cd /root/xformers \ + && git checkout 5f0419a \ + && git submodule update --init --recursive \ + && PYTORCH_ROCM_ARCH=$PYTORCH_ROCM_ARCH HIP_ARCHITECTURE=$HIP_ARCHITECTURE XFORMERS_CK_FLASH_ATTN=$XFORMERS_CK_FLASH_ATTN python3 setup.py install + +WORKDIR /root/FireRedASR \ No newline at end of file diff --git a/examples/benchmark_firered_asr.py b/examples/benchmark_firered_asr.py new file mode 100755 index 0000000..63385ea --- /dev/null +++ b/examples/benchmark_firered_asr.py @@ -0,0 +1,115 @@ +import os +import time +import torch +import numpy as np +from tqdm import tqdm + +import librosa +import soundfile as sf +import argparse +import torch + +from fireredasr.models.fireredasr import FireRedAsr + +from torch.profiler import profile as torch_profiler +from torch.profiler import ProfilerActivity, record_function + + +ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "XFORMERS") # Option: "NATIVE", "SDPA", "XFORMERS" + + +def load_model(model_path="pretrained_models/FireRedASR-AED-L"): + print("==========Load model:========") + model = FireRedAsr.from_pretrained("aed", model_path) + model.model.half() + model.model.cuda() + model.model.eval() + + return model + +def load_audio(wav_path): + print("==========load audio:=========") + audio, sr = sf.read(wav_path,dtype=np.float32) + print(len(audio), audio.dtype) + if sr != 16000: + audio = librosa.resample(audio, orig_sr=sr, target_sr=16000) + return audio + +def benchmark(model, wav_path, batch, warmpup=2, trials=10, enable_profile=False): + batch_wav_path = [wav_path] * batch + batch_uttid = list(range(batch)) + results = None + total_dur = None + + preprocess_start = time.time() + feats, lengths, durs = model.feat_extractor(batch_wav_path) + feats = feats.half() + feats, lengths = feats.cuda(), lengths.cuda() + preprocess_dur = time.time() - preprocess_start + print(f"preprocess_dur: {preprocess_dur:.3f} s") + total_dur = sum(durs) + + # Warmup + print("==========warmup========") + for _ in range(warmpup): + with torch.no_grad(): + _ = model.model.transcribe(feats, lengths) + + # Benchmark + print("==========start benchmark========") + total_time = 0 + results = [] + rtf_list = [] + if enable_profile: + warmup=1 + trials=1 + for _ in tqdm(range(trials)): + start = time.time() + with torch.no_grad(): + if enable_profile: + with torch_profiler(activities=[ + ProfilerActivity.CPU, + ProfilerActivity.CUDA], + record_shapes=True, + with_stack=True, + 
profile_memory=False) as prof: + #with record_function("model.model.transcribe"): + hyps = model.model.transcribe(feats, lengths) + print(prof.key_averages().table(sort_by="cuda_time_total")) + prof.export_chrome_trace(f"firered_asr_profile_{batch}_{ATTENTION_BACKEND}.json") + else: + hyps = model.model.transcribe(feats, lengths) + total_time += time.time() - start + elapsed = time.time() - start + + rtf = elapsed / total_dur if total_dur > 0 else 0 + for uttid, wav, hyp in zip(batch_uttid, batch_wav_path, hyps): + hyp = hyp[0] # only return 1-best + hyp_ids = [int(id) for id in hyp["yseq"].cpu()] + text = model.tokenizer.detokenize(hyp_ids) + results.append({"uttid": uttid, "text": text, "wav": wav, + "rtf": f"{rtf:.4f}"}) + rtf_list.append(rtf) + + avg_latency = total_time / trials + rps = batch / avg_latency + for res in results: + print(res) + avg_rtf = sum(rtf_list) / len(rtf_list) + return rps, avg_rtf + +if __name__ == "__main__": + audio_path = "examples/wav/TEST_MEETING_T0000000001_S00000.wav" + device = "cuda" if torch.cuda.is_available() else "cpu" + model_path = "pretrained_models/FireRedASR-AED-L" + enable_profile = False + batch_sizes = [1] + model = load_model(model_path) + + if enable_profile: + rps, avg_rtf = benchmark(model, audio_path, batch=1, enable_profile=True) + else: + for batch in batch_sizes: + print(f"=============== batch size {batch} ==========================") + rps, avg_rtf = benchmark(model, audio_path, batch=batch) + print(f"batch size: {batch}, average latency: {1.0/rps:.3f}s | RPS: {rps:.2f}, avg RTF: {avg_rtf:.3f}") diff --git a/fireredasr/models/fireredasr.py b/fireredasr/models/fireredasr.py index 9cb4e33..a9548a9 100644 --- a/fireredasr/models/fireredasr.py +++ b/fireredasr/models/fireredasr.py @@ -107,7 +107,7 @@ def transcribe(self, batch_uttid, batch_wav_path, args={}): def load_fireredasr_aed_model(model_path): - package = torch.load(model_path, map_location=lambda storage, loc: storage) + package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False) print("model args:", package["args"]) model = FireRedAsrAed.from_args(package["args"]) model.load_state_dict(package["model_state_dict"], strict=True) diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py old mode 100644 new mode 100755 index 2088b08..b6238fc --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -4,7 +4,20 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor +import math +import os +import torch +import torch.nn as nn +try: + import xformers.ops as xops + xformers_available = True +except Exception as e: + xformers_available = False + print("xformers is not available because: %s", str(e)) + +ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "SDPA") # Option: "NATIVE", "SDPA", "XFORMERS" +print("ATTENTION_BACKEND: ", ATTENTION_BACKEND) class TransformerDecoder(nn.Module): def __init__( @@ -222,8 +235,21 @@ def __init__(self, d_model, n_head, dropout=0.1): self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) self.w_vs = nn.Linear(d_model, n_head * self.d_k) - self.attention = DecoderScaledDotProductAttention( - temperature=self.d_k ** 0.5) + # Native multi-head attention + if ATTENTION_BACKEND.upper() == "NATIVE": + self.attention = DecoderScaledDotProductAttention(temperature=self.d_k ** 0.5) + # Torch SDPA + elif ATTENTION_BACKEND.upper() == "SDPA": + self.attention = 
DecoderTorchSDPA(temperature=self.d_k ** 0.5) + # XFormers attention + elif ATTENTION_BACKEND.upper() == "XFORMERS": + if not xformers_available: + print("ATTENTION_BACKEND='XFORMERS' selected, but the xformers package is not available. Please install xformers") + exit(1) + self.attention = DecoderXFormersAttention(self.n_head, self.d_k, self.d_model, temperature=self.d_k ** 0.5) + else: + print("Unsupported attention backend: ", ATTENTION_BACKEND) + exit(1) self.fc = nn.Linear(n_head * self.d_k, d_model) self.dropout = nn.Dropout(dropout) @@ -249,6 +275,7 @@ def forward(self, q, k, v, mask=None): return output +# Native SDPA class DecoderScaledDotProductAttention(nn.Module): def __init__(self, temperature): super().__init__() @@ -264,6 +291,45 @@ def forward(self, q, k, v, mask=None): else: attn = torch.softmax(attn, dim=-1) output = torch.matmul(attn, v) + + return output + + +# Torch SDPA +class DecoderTorchSDPA(nn.Module): + def __init__(self, temperature): + super().__init__() + self.temperature = temperature + self.scale = 1 / self.temperature + + def forward(self, q, k, v, mask=None): + output = F.scaled_dot_product_attention( + q, k, v, + scale = self.scale + ) + + return output + +# xFormers Attention +class DecoderXFormersAttention(nn.Module): + def __init__(self, n_head, d_k, d_model, temperature): + super().__init__() + self.temperature = temperature + self.n_head = n_head + self.d_k = d_k + self.d_model = d_model + + def forward(self, q, k, v, mask=None): + bs = q.size(0) + dtype = q.dtype + q = q.reshape(bs * self.n_head, -1, self.d_k).half() + k = k.reshape(bs * self.n_head, -1, self.d_k).half() + v = v.reshape(bs * self.n_head, -1, self.d_k).half() + + output = xops.memory_efficient_attention(q, k, v) + # reshape back to (bs, seq_len, d_model) + output = output.reshape(bs, self.n_head, -1, self.d_k).transpose(1, 2).contiguous().view(bs, -1, self.d_model).to(dtype) + return output diff --git a/requirements.txt b/requirements.txt old mode 100644 new mode 100755 index 40afd7e..b5ded77 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ peft>=0.13.2 sentencepiece torch>=2.0.0 transformers>=4.46.3 +librosa From c3673057f79049e4a837bd783556a7c239f20339 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Wed, 29 Oct 2025 08:29:59 +0000 Subject: [PATCH 02/13] Minor update --- README.md | 8 ++++---- fireredasr/models/module/transformer_decoder.py | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 2c81468..c6970cc 100755 --- a/README.md +++ b/README.md @@ -140,7 +140,7 @@ print(results) - FireRedASR-AED supports audio input up to 60s. Input longer than 60s may cause hallucination issues, and input exceeding 200s will trigger positional encoding errors. - FireRedASR-LLM supports audio input up to 30s. The behavior for longer input is currently unknown. -## FireRedASR-AED Optimization with ROCm +### FireRedASR-AED Optimization with ROCm 1. Build docker image with `docker/Dockerfile.rocm` to setup environemt ``` docker build --network=host -f docker/Dockerfile.rocm -t rocm/firered-asr-opt @@ -153,17 +153,17 @@ docker run -it --ipc=host --network=host --privileged --security-opt seccomp=un 3. Run performance test with native MHA (baseline) ```python -ATTENTION_BACKEND="NATIVE" python example/benchmark_firered_asr.py +ATTENTION_BACKEND="NATIVE" python examples/benchmark_firered_asr.py ``` 4. 
Run performance test with MHA using torch SDPA ```python -ATTENTION_BACKEND="SDPA" python example/benchmark_firered_asr.py +ATTENTION_BACKEND="SDPA" python examples/benchmark_firered_asr.py ``` 5. Run performance test with MHA using xFormers ```python -ATTENTION_BACKEND="XFORMERS" python example/benchmark_firered_asr.py +ATTENTION_BACKEND="XFORMERS" python examples/benchmark_firered_asr.py ``` ## Acknowledgements diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index b6238fc..e78afd3 100755 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -6,8 +6,6 @@ from torch import Tensor import math import os -import torch -import torch.nn as nn try: import xformers.ops as xops From d6de6bad098db19d83ce6a045111caad7b0088f9 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Mon, 3 Nov 2025 16:47:01 +0800 Subject: [PATCH 03/13] Add attention mask support in batch size > 1 condition for Torch SDPA and xFormers attention --- examples/benchmark_firered_asr.py | 26 ++++++--- .../models/module/transformer_decoder.py | 58 ++++++++++++++++++- 2 files changed, 73 insertions(+), 11 deletions(-) diff --git a/examples/benchmark_firered_asr.py b/examples/benchmark_firered_asr.py index 63385ea..f80cb08 100755 --- a/examples/benchmark_firered_asr.py +++ b/examples/benchmark_firered_asr.py @@ -48,6 +48,7 @@ def benchmark(model, wav_path, batch, warmpup=2, trials=10, enable_profile=False preprocess_dur = time.time() - preprocess_start print(f"preprocess_dur: {preprocess_dur:.3f} s") total_dur = sum(durs) + print(f"total input duration: {total_dur:.3f} s") # Warmup print("==========warmup========") @@ -96,20 +97,27 @@ def benchmark(model, wav_path, batch, warmpup=2, trials=10, enable_profile=False for res in results: print(res) avg_rtf = sum(rtf_list) / len(rtf_list) - return rps, avg_rtf + return rps, avg_latency, avg_rtf if __name__ == "__main__": - audio_path = "examples/wav/TEST_MEETING_T0000000001_S00000.wav" - device = "cuda" if torch.cuda.is_available() else "cpu" - model_path = "pretrained_models/FireRedASR-AED-L" - enable_profile = False - batch_sizes = [1] + parser = argparse.ArgumentParser(prog='Benchmark scripts for FireRedASR', usage='%(prog)s [options]') + parser.add_argument('-b', '--batch_sizes', type=int, nargs='+', default=1, help='List of batch sizes for performance evaluation') + parser.add_argument('-m', '--model_path', type=str, default="pretrained_models/FireRedASR-AED-L", help='Path to model directory') + parser.add_argument('-a', '--audio_path', type=str, default='examples/wav/TEST_MEETING_T0000000001_S00000.wav', help="Input audio path") + parser.add_argument('-d', '--device', type=str, default='cuda', help="Target inference device") + parser.add_argument('-p', '--profile', action='store_true', help='Enable torch profiler') + args = parser.parse_args() + audio_path = args.audio_path + model_path = args.model_path + device = args.device + enable_profile = args.profile + batch_sizes = args.batch_sizes # [1, 4, 8, 16, 32] model = load_model(model_path) if enable_profile: - rps, avg_rtf = benchmark(model, audio_path, batch=1, enable_profile=True) + rps, avg_rtf, avg_latency = benchmark(model, audio_path, batch=1, enable_profile=True) else: for batch in batch_sizes: print(f"=============== batch size {batch} ==========================") - rps, avg_rtf = benchmark(model, audio_path, batch=batch) - print(f"batch size: {batch}, average latency: {1.0/rps:.3f}s | RPS: {rps:.2f}, avg 
RTF: {avg_rtf:.3f}") + rps, avg_latency, avg_rtf = benchmark(model, audio_path, batch=batch) + print(f"batch size: {batch}, average latency: {avg_latency:.3f}s | RPS: {rps:.2f}, avg RTF: {avg_rtf:.3f}") diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index e78afd3..959e95e 100755 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -301,13 +301,26 @@ def __init__(self, temperature): self.scale = 1 / self.temperature def forward(self, q, k, v, mask=None): - output = F.scaled_dot_product_attention( + bs = q.size(0) + output = None + if bs == 1: + output = F.scaled_dot_product_attention( q, k, v, scale = self.scale ) + else: + if mask is not None: + if mask.dtype != torch.bool: + mask = mask.eq(1) + output = F.scaled_dot_product_attention( + q, k, v, + attn_mask = mask, + scale = self.scale + ) return output + # xFormers Attention class DecoderXFormersAttention(nn.Module): def __init__(self, n_head, d_k, d_model, temperature): @@ -319,12 +332,53 @@ def __init__(self, n_head, d_k, d_model, temperature): def forward(self, q, k, v, mask=None): bs = q.size(0) + # Save lengths + q_len = q.size(2) # seq_len_q + k_len = k.size(2) # seq_len_k dtype = q.dtype + q = q.reshape(bs * self.n_head, -1, self.d_k).half() k = k.reshape(bs * self.n_head, -1, self.d_k).half() v = v.reshape(bs * self.n_head, -1, self.d_k).half() - output = xops.memory_efficient_attention(q, k, v) + output = None + if bs == 1: + output = xops.memory_efficient_attention(q, k, v) + else: + attn_bias = None + # --- AUTO-DETECT causal self-attention --- + # q and k are the same tensor object in memory when this is pure self-attn + if q_len == k_len and q.data_ptr() == k.data_ptr(): + attn_bias = xops.LowerTriangularMask() + + # --- Cross-attention / padding mask --- + elif mask is not None: + mask = mask.to(torch.bool) + + # If mask only has 1 in q_len dimension, expand it + if mask.size(2) == 1 and q_len > 1: + mask = mask.expand(bs, 1, q_len, k_len) + + # Expand mask for all heads + mask = mask.expand(bs, self.n_head, q_len, k_len) \ + .reshape(bs * self.n_head, q_len, k_len) + + # Alignment requirement for xformers: pad allocation to multiple of 8 + pad_k = ((k_len + 7) // 8) * 8 + pad_q = ((q_len + 7) // 8) * 8 + + bias_full = torch.zeros(bs * self.n_head, pad_q, pad_k, + dtype=q.dtype, device=q.device) + + bias_full[:, :q_len, :k_len].masked_fill_(~mask, float("-inf")) + + # Slice down to actual shape but keep aligned backing storage + attn_bias = bias_full[:, :q_len, :k_len] + + # --- Run memory-efficient attention --- + output = xops.memory_efficient_attention(q, k, v, + attn_bias=attn_bias if attn_bias is not None else None) + # reshape back to (bs, seq_len, d_model) output = output.reshape(bs, self.n_head, -1, self.d_k).transpose(1, 2).contiguous().view(bs, -1, self.d_model).to(dtype) From e2103183e6093f8ffd3e9b89392cb699fbef2629 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Mon, 3 Nov 2025 15:19:22 +0000 Subject: [PATCH 04/13] Refactor xFormers bias_attn handeling with BlockDiagonalMask for batch size > 1 case --- examples/benchmark_firered_asr.py | 27 ++++--- .../models/module/transformer_decoder.py | 79 +++++++++---------- 2 files changed, 51 insertions(+), 55 deletions(-) diff --git a/examples/benchmark_firered_asr.py b/examples/benchmark_firered_asr.py index f80cb08..7ae7cd1 100755 --- a/examples/benchmark_firered_asr.py +++ b/examples/benchmark_firered_asr.py @@ -46,9 +46,9 @@ def benchmark(model, 
wav_path, batch, warmpup=2, trials=10, enable_profile=False feats = feats.half() feats, lengths = feats.cuda(), lengths.cuda() preprocess_dur = time.time() - preprocess_start - print(f"preprocess_dur: {preprocess_dur:.3f} s") + print(f"preprocess duration: {preprocess_dur:.3f} s") total_dur = sum(durs) - print(f"total input duration: {total_dur:.3f} s") + print(f"total input audio duration: {total_dur:.3f} s") # Warmup print("==========warmup========") @@ -68,13 +68,12 @@ def benchmark(model, wav_path, batch, warmpup=2, trials=10, enable_profile=False start = time.time() with torch.no_grad(): if enable_profile: - with torch_profiler(activities=[ - ProfilerActivity.CPU, - ProfilerActivity.CUDA], - record_shapes=True, - with_stack=True, - profile_memory=False) as prof: - #with record_function("model.model.transcribe"): + with torch_profiler( + activities=[ProfilerActivity.CPU, + ProfilerActivity.CUDA], + record_shapes=True, + with_stack=True, + profile_memory=False) as prof: hyps = model.model.transcribe(feats, lengths) print(prof.key_averages().table(sort_by="cuda_time_total")) prof.export_chrome_trace(f"firered_asr_profile_{batch}_{ATTENTION_BACKEND}.json") @@ -94,8 +93,10 @@ def benchmark(model, wav_path, batch, warmpup=2, trials=10, enable_profile=False avg_latency = total_time / trials rps = batch / avg_latency - for res in results: - print(res) + # Only print first result for debug purpose + print("results[0]: ", results[0]) + #for res in results: + #print(res) avg_rtf = sum(rtf_list) / len(rtf_list) return rps, avg_latency, avg_rtf @@ -111,13 +112,13 @@ def benchmark(model, wav_path, batch, warmpup=2, trials=10, enable_profile=False model_path = args.model_path device = args.device enable_profile = args.profile - batch_sizes = args.batch_sizes # [1, 4, 8, 16, 32] + batch_sizes = args.batch_sizes # [1, 4, 8, 16, 32, 64, 128, 256] model = load_model(model_path) if enable_profile: rps, avg_rtf, avg_latency = benchmark(model, audio_path, batch=1, enable_profile=True) else: for batch in batch_sizes: - print(f"=============== batch size {batch} ==========================") + print(f"*************************** batch size {batch} ***************************") rps, avg_latency, avg_rtf = benchmark(model, audio_path, batch=batch) print(f"batch size: {batch}, average latency: {avg_latency:.3f}s | RPS: {rps:.2f}, avg RTF: {avg_rtf:.3f}") diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index 959e95e..2a047b9 100755 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -9,6 +9,7 @@ try: import xformers.ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask xformers_available = True except Exception as e: xformers_available = False @@ -257,9 +258,6 @@ def forward(self, q, k, v, mask=None): q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k) k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k) v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) if mask is not None: mask = mask.unsqueeze(1) @@ -281,6 +279,9 @@ def __init__(self, temperature): self.INF = float("inf") def forward(self, q, k, v, mask=None): + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature if mask is not None: mask = mask.eq(0) @@ -302,6 +303,9 @@ def __init__(self, temperature): def forward(self, q, k, v, mask=None): bs = q.size(0) + q = 
q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) output = None if bs == 1: output = F.scaled_dot_product_attention( @@ -329,60 +333,51 @@ def __init__(self, n_head, d_k, d_model, temperature): self.n_head = n_head self.d_k = d_k self.d_model = d_model + self.scale = 1.0 / temperature def forward(self, q, k, v, mask=None): + original_query = q bs = q.size(0) - # Save lengths - q_len = q.size(2) # seq_len_q - k_len = k.size(2) # seq_len_k - dtype = q.dtype + q_len = q.size(1) # seq_len_q + k_len = k.size(1) # seq_len_k - q = q.reshape(bs * self.n_head, -1, self.d_k).half() - k = k.reshape(bs * self.n_head, -1, self.d_k).half() - v = v.reshape(bs * self.n_head, -1, self.d_k).half() + # Reshape and add batch dimension for input of xformers memory_efficient_attention_forward + q = q.reshape(-1, self.n_head, self.d_k).unsqueeze(0).half() + k = k.reshape(-1, self.n_head, self.d_k).unsqueeze(0).half() + v = v.reshape(-1, self.n_head, self.d_k).unsqueeze(0).half() output = None if bs == 1: - output = xops.memory_efficient_attention(q, k, v) + output = xops.memory_efficient_attention_forward(q, + k, + v, + scale=self.scale, + op=xops.fmha.ck.FwOp) else: attn_bias = None - # --- AUTO-DETECT causal self-attention --- + # --- Detect if it is causal self-attention --- # q and k are the same tensor object in memory when this is pure self-attn if q_len == k_len and q.data_ptr() == k.data_ptr(): attn_bias = xops.LowerTriangularMask() # --- Cross-attention / padding mask --- elif mask is not None: - mask = mask.to(torch.bool) - - # If mask only has 1 in q_len dimension, expand it - if mask.size(2) == 1 and q_len > 1: - mask = mask.expand(bs, 1, q_len, k_len) - - # Expand mask for all heads - mask = mask.expand(bs, self.n_head, q_len, k_len) \ - .reshape(bs * self.n_head, q_len, k_len) - - # Alignment requirement for xformers: pad allocation to multiple of 8 - pad_k = ((k_len + 7) // 8) * 8 - pad_q = ((q_len + 7) // 8) * 8 - - bias_full = torch.zeros(bs * self.n_head, pad_q, pad_k, - dtype=q.dtype, device=q.device) - - bias_full[:, :q_len, :k_len].masked_fill_(~mask, float("-inf")) - - # Slice down to actual shape but keep aligned backing storage - attn_bias = bias_full[:, :q_len, :k_len] - - # --- Run memory-efficient attention --- - output = xops.memory_efficient_attention(q, k, v, - attn_bias=attn_bias if attn_bias is not None else None) - - # reshape back to (bs, seq_len, d_model) - output = output.reshape(bs, self.n_head, -1, self.d_k).transpose(1, 2).contiguous().view(bs, -1, self.d_model).to(dtype) - - return output + attn_bias = BlockDiagonalMask.from_seqlens([q_len] * bs, [k_len] * bs, device=q.device) + # print("==========================================") + # print("attn_bias: ", attn_bias) + # print("query.shape: ", query.shape) + # print("key.shape: ", key.shape) + # print("value.shape: ", value.shape) + # print("\n") + output = xops.memory_efficient_attention_forward( + q, + k, + v, + attn_bias=attn_bias, + scale=self.scale, + op=xops.fmha.ck.FwOp) + + return output.view_as(original_query) class PositionwiseFeedForward(nn.Module): From 4dbda0b2cf4eb4c06f5a8ca328ebf21243b2fac1 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Mon, 3 Nov 2025 15:50:15 +0000 Subject: [PATCH 05/13] Fix dtype --- fireredasr/models/module/transformer_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index 2a047b9..b855a7c 100755 --- 
a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -377,7 +377,7 @@ def forward(self, q, k, v, mask=None): scale=self.scale, op=xops.fmha.ck.FwOp) - return output.view_as(original_query) + return output.view_as(original_query).to(original_query.dtype) class PositionwiseFeedForward(nn.Module): From aa3f757092838e2bb08561d5b52441c780b76286 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Tue, 4 Nov 2025 07:11:30 +0000 Subject: [PATCH 06/13] Fix attn_bias creation in cross attention mask for audio input with variable length --- fireredasr/models/module/transformer_decoder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index b855a7c..4ae8f75 100755 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -9,7 +9,7 @@ try: import xformers.ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalMask + from xformers.ops.fmha.attn_bias import BlockDiagonalMask, BlockDiagonalCausalMask xformers_available = True except Exception as e: xformers_available = False @@ -358,11 +358,12 @@ def forward(self, q, k, v, mask=None): # --- Detect if it is causal self-attention --- # q and k are the same tensor object in memory when this is pure self-attn if q_len == k_len and q.data_ptr() == k.data_ptr(): - attn_bias = xops.LowerTriangularMask() + attn_bias = BlockDiagonalCausalMask.from_seqlens([q_len] * bs, device=q.device) # --- Cross-attention / padding mask --- elif mask is not None: - attn_bias = BlockDiagonalMask.from_seqlens([q_len] * bs, [k_len] * bs, device=q.device) + encoder_seq_lens = [mask[i].flatten().sum() for i in range(mask.shape[0])] + attn_bias = BlockDiagonalMask.from_seqlens([q_len] * bs, encoder_seq_lens, device=q.device) # print("==========================================") # print("attn_bias: ", attn_bias) # print("query.shape: ", query.shape) From 7efe26f84a8a81693936faa4e805ce61346ee742 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Thu, 6 Nov 2025 09:41:49 +0000 Subject: [PATCH 07/13] Refactor xFormers backend --- .../models/module/transformer_decoder.py | 114 +++++++++++------- 1 file changed, 71 insertions(+), 43 deletions(-) diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index 4ae8f75..bd0c0bf 100755 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -9,7 +9,6 @@ try: import xformers.ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalMask, BlockDiagonalCausalMask xformers_available = True except Exception as e: xformers_available = False @@ -224,7 +223,7 @@ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask, class DecoderMultiHeadAttention(nn.Module): - def __init__(self, d_model, n_head, dropout=0.1): + def __init__(self, d_model, n_head, dropout=0.1, attention_type=None): super().__init__() self.d_model = d_model self.n_head = n_head @@ -245,7 +244,7 @@ def __init__(self, d_model, n_head, dropout=0.1): if not xformers_available: print("ATTENTION_BACKEND='XFORMERS' selected, but the xformers package is not available. 
Please install xformers") exit(1) - self.attention = DecoderXFormersAttention(self.n_head, self.d_k, self.d_model, temperature=self.d_k ** 0.5) + self.attention = DecoderXFormersAttention(self.n_head, self.d_k, self.d_model, temperature=self.d_k ** 0.5, attention_type=attention_type) else: print("Unsupported attention backend: ", ATTENTION_BACKEND) exit(1) @@ -258,6 +257,9 @@ def forward(self, q, k, v, mask=None): q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k) k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k) v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) if mask is not None: mask = mask.unsqueeze(1) @@ -279,9 +281,6 @@ def __init__(self, temperature): self.INF = float("inf") def forward(self, q, k, v, mask=None): - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature if mask is not None: mask = mask.eq(0) @@ -303,9 +302,6 @@ def __init__(self, temperature): def forward(self, q, k, v, mask=None): bs = q.size(0) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) output = None if bs == 1: output = F.scaled_dot_product_attention( @@ -325,60 +321,92 @@ def forward(self, q, k, v, mask=None): return output +class XFormersAttentionMetadata: + """Metadata for XFormers Attention backend """ + def __init__(self, attention_type): + self.attention_type = attention_type + self.attn_bias = None + + def set_self_attn_bias(self): + if self.attention_type == "self_attention": + self.attn_bias = xops.LowerTriangularMask() + else: + print("Unknown attention type used, only support `self_attention`") + + def set_cross_attn_bias(self, mask, bs, q_len, k_len): + if self.attention_type == "cross_attention": + mask = mask.to(torch.bool) + + # If mask only has 1 in q_len dimension, expand it + if mask.size(2) == 1 and q_len > 1: + mask = mask.expand(bs, 1, q_len, k_len) + + # Expand mask for all heads + mask = mask.expand(bs, self.n_head, q_len, k_len) \ + .reshape(bs * self.n_head, q_len, k_len) + + # Alignment requirement for xformers: pad allocation to multiple of 8 + pad_k = ((k_len + 7) // 8) * 8 + pad_q = ((q_len + 7) // 8) * 8 + + bias_full = torch.zeros(bs * self.n_head, pad_q, pad_k, + dtype=q.dtype, device=q.device) + + bias_full[:, :q_len, :k_len].masked_fill_(~mask, float("-inf")) + + # Slice down to actual shape but keep aligned backing storage + self.attn_bias = bias_full[:, :q_len, :k_len] + + print("Unknown attention type used, only support `self_attention` and `cross_attention`") + + def get_attn_bias(self): + return self.attn_bias + # xFormers Attention class DecoderXFormersAttention(nn.Module): - def __init__(self, n_head, d_k, d_model, temperature): + def __init__(self, n_head, d_k, d_model, temperature, attention_type): super().__init__() self.temperature = temperature self.n_head = n_head self.d_k = d_k self.d_model = d_model - self.scale = 1.0 / temperature + self.attention_metadata = XFormersAttentionMetadata(attention_type) def forward(self, q, k, v, mask=None): original_query = q bs = q.size(0) - q_len = q.size(1) # seq_len_q - k_len = k.size(1) # seq_len_k + # Save lengths + q_len = q.size(2) # seq_len_q + k_len = k.size(2) # seq_len_k + dtype = q.dtype - # Reshape and add batch dimension for input of xformers memory_efficient_attention_forward - q = q.reshape(-1, self.n_head, self.d_k).unsqueeze(0).half() - k = k.reshape(-1, self.n_head, self.d_k).unsqueeze(0).half() - v = v.reshape(-1, self.n_head, 
self.d_k).unsqueeze(0).half() + q = q.reshape(bs * self.n_head, -1, self.d_k).to(torch.bfloat16) + k = k.reshape(bs * self.n_head, -1, self.d_k).to(torch.bfloat16) + v = v.reshape(bs * self.n_head, -1, self.d_k).to(torch.bfloat16) output = None if bs == 1: - output = xops.memory_efficient_attention_forward(q, - k, - v, - scale=self.scale, - op=xops.fmha.ck.FwOp) + output = xops.memory_efficient_attention(q, k, v) else: attn_bias = None - # --- Detect if it is causal self-attention --- + # --- AUTO-DETECT causal self-attention --- # q and k are the same tensor object in memory when this is pure self-attn - if q_len == k_len and q.data_ptr() == k.data_ptr(): - attn_bias = BlockDiagonalCausalMask.from_seqlens([q_len] * bs, device=q.device) + if self.attention_metadata.attention_type == "self_attention": + attn_bias = xops.LowerTriangularMask() # --- Cross-attention / padding mask --- - elif mask is not None: - encoder_seq_lens = [mask[i].flatten().sum() for i in range(mask.shape[0])] - attn_bias = BlockDiagonalMask.from_seqlens([q_len] * bs, encoder_seq_lens, device=q.device) - # print("==========================================") - # print("attn_bias: ", attn_bias) - # print("query.shape: ", query.shape) - # print("key.shape: ", key.shape) - # print("value.shape: ", value.shape) - # print("\n") - output = xops.memory_efficient_attention_forward( - q, - k, - v, - attn_bias=attn_bias, - scale=self.scale, - op=xops.fmha.ck.FwOp) - - return output.view_as(original_query).to(original_query.dtype) + #elif mask is not None: + elif self.attention_metadata.attention_type == "cross_attention" and mask is not None: + if self.attention_metadata.get_attn_bias() == None: + self.attention_metadata.set_cross_attn_bias(mask, bs, q_len, k_len) + attn_bias = self.attention_metadata.get_attn_bias() + + # --- Run memory-efficient attention --- + output = xops.memory_efficient_attention(q, k, v, + attn_bias=attn_bias if attn_bias is not None else None) + + # reshape back to (bs, seq_len, d_model) + return output.view_as(original_query).to(dtype) class PositionwiseFeedForward(nn.Module): From 2d6af13b959b88b9efd1ff835603b4388ab11f68 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Thu, 6 Nov 2025 09:42:20 +0000 Subject: [PATCH 08/13] Update benchmark script to load audio from directory --- examples/benchmark_firered_asr.py | 51 +++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 12 deletions(-) diff --git a/examples/benchmark_firered_asr.py b/examples/benchmark_firered_asr.py index 7ae7cd1..2cfa0ef 100755 --- a/examples/benchmark_firered_asr.py +++ b/examples/benchmark_firered_asr.py @@ -21,7 +21,7 @@ def load_model(model_path="pretrained_models/FireRedASR-AED-L"): print("==========Load model:========") model = FireRedAsr.from_pretrained("aed", model_path) - model.model.half() + model.model.to(torch.bfloat16) model.model.cuda() model.model.eval() @@ -35,15 +35,41 @@ def load_audio(wav_path): audio = librosa.resample(audio, orig_sr=sr, target_sr=16000) return audio -def benchmark(model, wav_path, batch, warmpup=2, trials=10, enable_profile=False): - batch_wav_path = [wav_path] * batch +def benchmark(model, audio_dir, batch, warmpup=2, trials=10, enable_profile=False): + # Get list of .wav files (case-insensitive) + batch_wav_path = [] + + # Collect file paths and durations + file_durations = [] + input_wav_path_list = [os.path.join(audio_dir, f) + for f in os.listdir(audio_dir) + if f.lower().endswith('.wav')] + + for file_path in input_wav_path_list: + try: + y, sr = librosa.load(file_path, sr=None) # 
keep original sampling rate + duration = len(y) / sr + file_durations.append((file_path, duration)) + except Exception as e: + print(f"Error processing {file_path}: {e}") + + # Sort by duration (longest first) + file_durations.sort(key=lambda x: x[1], reverse=True) + + # Take top N for batch + batch_wav_path = [fp for fp, dur in file_durations[:batch]] + + # Optional: print results + for i, (fp, dur) in enumerate(file_durations[:batch], start=1): + print(f"{i}. {fp} - {dur:.2f} sec") + batch_uttid = list(range(batch)) results = None total_dur = None preprocess_start = time.time() feats, lengths, durs = model.feat_extractor(batch_wav_path) - feats = feats.half() + feats = feats.to(torch.bfloat16) feats, lengths = feats.cuda(), lengths.cuda() preprocess_dur = time.time() - preprocess_start print(f"preprocess duration: {preprocess_dur:.3f} s") @@ -93,10 +119,11 @@ def benchmark(model, wav_path, batch, warmpup=2, trials=10, enable_profile=False avg_latency = total_time / trials rps = batch / avg_latency - # Only print first result for debug purpose - print("results[0]: ", results[0]) - #for res in results: - #print(res) + #Only print last result for debug purpose + #print("results[0]: ", results[0]) + print("Only print last run results for debug purpose...") + for res in results[-batch:]: + print(res) avg_rtf = sum(rtf_list) / len(rtf_list) return rps, avg_latency, avg_rtf @@ -104,11 +131,11 @@ def benchmark(model, wav_path, batch, warmpup=2, trials=10, enable_profile=False parser = argparse.ArgumentParser(prog='Benchmark scripts for FireRedASR', usage='%(prog)s [options]') parser.add_argument('-b', '--batch_sizes', type=int, nargs='+', default=1, help='List of batch sizes for performance evaluation') parser.add_argument('-m', '--model_path', type=str, default="pretrained_models/FireRedASR-AED-L", help='Path to model directory') - parser.add_argument('-a', '--audio_path', type=str, default='examples/wav/TEST_MEETING_T0000000001_S00000.wav', help="Input audio path") + parser.add_argument('-a', '--audio_dir', type=str, default='examples/wav', help="Path to input audio directory") parser.add_argument('-d', '--device', type=str, default='cuda', help="Target inference device") parser.add_argument('-p', '--profile', action='store_true', help='Enable torch profiler') args = parser.parse_args() - audio_path = args.audio_path + audio_dir = args.audio_dir model_path = args.model_path device = args.device enable_profile = args.profile @@ -116,9 +143,9 @@ def benchmark(model, wav_path, batch, warmpup=2, trials=10, enable_profile=False model = load_model(model_path) if enable_profile: - rps, avg_rtf, avg_latency = benchmark(model, audio_path, batch=1, enable_profile=True) + rps, avg_rtf, avg_latency = benchmark(model, audio_dir, batch=1, enable_profile=True) else: for batch in batch_sizes: print(f"*************************** batch size {batch} ***************************") - rps, avg_latency, avg_rtf = benchmark(model, audio_path, batch=batch) + rps, avg_latency, avg_rtf = benchmark(model, audio_dir, batch=batch) print(f"batch size: {batch}, average latency: {avg_latency:.3f}s | RPS: {rps:.2f}, avg RTF: {avg_rtf:.3f}") From b9bf68ee088ab05f8612d4178874a3945f5e491d Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Fri, 7 Nov 2025 08:45:36 +0000 Subject: [PATCH 09/13] Refactor benchmark run benchmark with audio directory with variable length data --- examples/benchmark_firered_asr.py | 102 +++++++++++------- .../models/module/transformer_decoder.py | 11 ++ 2 files changed, 77 insertions(+), 36 deletions(-) diff 
--git a/examples/benchmark_firered_asr.py b/examples/benchmark_firered_asr.py index 2cfa0ef..a15cc90 100755 --- a/examples/benchmark_firered_asr.py +++ b/examples/benchmark_firered_asr.py @@ -35,38 +35,11 @@ def load_audio(wav_path): audio = librosa.resample(audio, orig_sr=sr, target_sr=16000) return audio -def benchmark(model, audio_dir, batch, warmpup=2, trials=10, enable_profile=False): - # Get list of .wav files (case-insensitive) - batch_wav_path = [] - - # Collect file paths and durations - file_durations = [] - input_wav_path_list = [os.path.join(audio_dir, f) - for f in os.listdir(audio_dir) - if f.lower().endswith('.wav')] - - for file_path in input_wav_path_list: - try: - y, sr = librosa.load(file_path, sr=None) # keep original sampling rate - duration = len(y) / sr - file_durations.append((file_path, duration)) - except Exception as e: - print(f"Error processing {file_path}: {e}") - - # Sort by duration (longest first) - file_durations.sort(key=lambda x: x[1], reverse=True) - - # Take top N for batch - batch_wav_path = [fp for fp, dur in file_durations[:batch]] - - # Optional: print results - for i, (fp, dur) in enumerate(file_durations[:batch], start=1): - print(f"{i}. {fp} - {dur:.2f} sec") - - batch_uttid = list(range(batch)) +def run(model, batch_wav_path, warmpup=2, trials=10, enable_profile=False, offset=0): + batch_uttid = list(range(offset, offset + len(batch_wav_path), 1)) results = None total_dur = None - + preprocess_start = time.time() feats, lengths, durs = model.feat_extractor(batch_wav_path) feats = feats.to(torch.bfloat16) @@ -74,8 +47,8 @@ def benchmark(model, audio_dir, batch, warmpup=2, trials=10, enable_profile=Fals preprocess_dur = time.time() - preprocess_start print(f"preprocess duration: {preprocess_dur:.3f} s") total_dur = sum(durs) - print(f"total input audio duration: {total_dur:.3f} s") - + avg_audio_dur_per_sample = total_dur / len(durs) + print(f"total input audio duration: {total_dur:.3f} s, avg input audio duration per sample: {avg_audio_dur_per_sample:.3f} s") # Warmup print("==========warmup========") for _ in range(warmpup): @@ -125,7 +98,56 @@ def benchmark(model, audio_dir, batch, warmpup=2, trials=10, enable_profile=Fals for res in results[-batch:]: print(res) avg_rtf = sum(rtf_list) / len(rtf_list) - return rps, avg_latency, avg_rtf + print(f"Finished benchmark test for batch size: {len(batch_wav_path)}, average latency: {avg_latency:.3f}s | RPS: {rps:.2f}, avg RTF: {avg_rtf:.3f}") + + return rps, avg_latency, avg_rtf, avg_audio_dur_per_sample, results[-batch:] + +def benchmark(model, audio_dir, batch, warmpup=2, trials=10, enable_profile=False): + # Get list of .wav files (case-insensitive) + batch_wav_path = [] + + # Collect file paths and durations + file_durations = [] + input_wav_path_list = [os.path.join(audio_dir, f) + for f in os.listdir(audio_dir) + if f.lower().endswith('.wav')] + + for file_path in input_wav_path_list: + try: + y, sr = librosa.load(file_path, sr=None) # keep original sampling rate + duration = len(y) / sr + file_durations.append((file_path, duration)) + except Exception as e: + print(f"Error processing {file_path}: {e}") + + # Sort by duration (longest first) + file_durations.sort(key=lambda x: x[1], reverse=True) + # Optional: print all audio file with duration + for i, (fp, dur) in enumerate(file_durations, start=1): + print(f"{i}. 
{fp} - {dur:.2f} sec") + + dataset_size = len(input_wav_path_list) + # Loop through data in batches + benchmark_results = [] + e2e_start = time.time() + for start in range(0, dataset_size - dataset_size % batch, batch): + #batch_wav_path, dur = file_durations[start:start+batch] + batch_wav_path = [path for path, _ in file_durations[start:start + batch]] + print(f"Processing {batch} batched data from index {start} to {start + batch-1}") + rps, avg_latency, avg_rtf, avg_audio_dur_per_sample, model_results = run(model, batch_wav_path, warmpup, trials, enable_profile, offset=start) + benchmark_results.append((batch, avg_audio_dur_per_sample, avg_latency, rps, avg_rtf, model_results)) + + # Process remaining data if any + remainder = dataset_size % batch + if remainder: + #last_batch_wav_path = file_durations[-remainder:] + last_batch_wav_path = [path for path, _ in file_durations[-remainder:]] + print(f"Processing {remainder} remaining data : {last_batch_wav_path}") + rps, avg_latency, avg_rtf, avg_audio_dur_per_sample, model_results = run(model, last_batch_wav_path, warmpup, trials, enable_profile, offset=start) + benchmark_results.append((batch, avg_audio_dur_per_sample, avg_latency, rps, avg_rtf, model_results)) + e2e_duration = time.time() - e2e_start + + return benchmark_results, e2e_duration if __name__ == "__main__": parser = argparse.ArgumentParser(prog='Benchmark scripts for FireRedASR', usage='%(prog)s [options]') @@ -143,9 +165,17 @@ def benchmark(model, audio_dir, batch, warmpup=2, trials=10, enable_profile=Fals model = load_model(model_path) if enable_profile: - rps, avg_rtf, avg_latency = benchmark(model, audio_dir, batch=1, enable_profile=True) + #rps, avg_rtf, avg_latency = benchmark(model, audio_dir, batch=1, enable_profile=True) + benchmark_results, e2e_duration = benchmark(model, audio_dir, batch=1, enable_profile=enable_profile) else: for batch in batch_sizes: print(f"*************************** batch size {batch} ***************************") - rps, avg_latency, avg_rtf = benchmark(model, audio_dir, batch=batch) - print(f"batch size: {batch}, average latency: {avg_latency:.3f}s | RPS: {rps:.2f}, avg RTF: {avg_rtf:.3f}") + #rps, avg_latency, avg_rtf = benchmark(model, audio_dir, batch=batch) + benchmark_results, e2e_duration = benchmark(model, audio_dir, batch=batch, enable_profile=enable_profile) + + #print(f"batch size: {batch}, average latency: {avg_latency:.3f}s | RPS: {rps:.2f}, avg RTF: {avg_rtf:.3f}") + print(f"\nbatch size: {batch}, e2e latency: {e2e_duration} s") + for res in benchmark_results: + print(res[5]) + for res in benchmark_results: + print(f"batch size: {res[0]}, avg audio duration per sample: {res[1]:.3f} s, avg inference latency {res[2]:.3f} s | RPS: {res[3]:.2f}, avg RTF: {res[4]:.3f}") diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index bd0c0bf..a81b0f6 100755 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -49,6 +49,10 @@ def __init__( def batch_beam_search(self, encoder_outputs, src_masks, beam_size=1, nbest=1, decode_max_len=0, softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0): + if ATTENTION_BACKEND.upper() == "XFORMERS": + for dec_layer in self.layer_stack: + dec_layer.self_attn.attention.reset_attn_bias() + dec_layer.cross_attn.attention.reset_attn_bias() B = beam_size N, Ti, H = encoder_outputs.size() device = encoder_outputs.device @@ -148,6 +152,7 @@ def batch_beam_search(self, encoder_outputs, src_masks, } 
n_nbest_hyps.append(new_hyp) nbest_hyps.append(n_nbest_hyps) + return nbest_hyps def ignored_target_position_is_0(self, padded_targets, ignore_id): @@ -362,6 +367,9 @@ def set_cross_attn_bias(self, mask, bs, q_len, k_len): def get_attn_bias(self): return self.attn_bias + def reset_attn_bias(self): + self.attn_bias = None + # xFormers Attention class DecoderXFormersAttention(nn.Module): def __init__(self, n_head, d_k, d_model, temperature, attention_type): @@ -372,6 +380,9 @@ def __init__(self, n_head, d_k, d_model, temperature, attention_type): self.d_model = d_model self.attention_metadata = XFormersAttentionMetadata(attention_type) + def reset_attn_bias(self): + self.attention_metadata.reset_attn_bias() + def forward(self, q, k, v, mask=None): original_query = q bs = q.size(0) From e91314dec700225ce402698a4f08563f8e3bd8d0 Mon Sep 17 00:00:00 2001 From: Xiake Sun Date: Wed, 12 Nov 2025 15:49:28 +0000 Subject: [PATCH 10/13] Algin beam search parameters, set model default precision as FP16 --- examples/benchmark_firered_asr.py | 36 ++++++++++++++++++------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/examples/benchmark_firered_asr.py b/examples/benchmark_firered_asr.py index a15cc90..4acf51c 100755 --- a/examples/benchmark_firered_asr.py +++ b/examples/benchmark_firered_asr.py @@ -3,7 +3,7 @@ import torch import numpy as np from tqdm import tqdm - +import json import librosa import soundfile as sf import argparse @@ -12,7 +12,7 @@ from fireredasr.models.fireredasr import FireRedAsr from torch.profiler import profile as torch_profiler -from torch.profiler import ProfilerActivity, record_function +from torch.profiler import ProfilerActivity ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "XFORMERS") # Option: "NATIVE", "SDPA", "XFORMERS" @@ -21,7 +21,7 @@ def load_model(model_path="pretrained_models/FireRedASR-AED-L"): print("==========Load model:========") model = FireRedAsr.from_pretrained("aed", model_path) - model.model.to(torch.bfloat16) + model.model.to(torch.float16) model.model.cuda() model.model.eval() @@ -42,7 +42,7 @@ def run(model, batch_wav_path, warmpup=2, trials=10, enable_profile=False, offse preprocess_start = time.time() feats, lengths, durs = model.feat_extractor(batch_wav_path) - feats = feats.to(torch.bfloat16) + feats = feats.to(torch.float16) feats, lengths = feats.cuda(), lengths.cuda() preprocess_dur = time.time() - preprocess_start print(f"preprocess duration: {preprocess_dur:.3f} s") @@ -53,7 +53,7 @@ def run(model, batch_wav_path, warmpup=2, trials=10, enable_profile=False, offse print("==========warmup========") for _ in range(warmpup): with torch.no_grad(): - _ = model.model.transcribe(feats, lengths) + _ = model.model.transcribe(feats, lengths, beam_size=3, nbest=1, decode_max_len=0, softmax_smoothing=1.25, length_penalty=0.6, eos_penalty=1.0) # repetition_penalty=1.0, decode_min_len=0, temperature=1.0 used only for llm # Benchmark print("==========start benchmark========") @@ -73,11 +73,11 @@ def run(model, batch_wav_path, warmpup=2, trials=10, enable_profile=False, offse record_shapes=True, with_stack=True, profile_memory=False) as prof: - hyps = model.model.transcribe(feats, lengths) + hyps = model.model.transcribe(feats, lengths, beam_size=3, nbest=1, decode_max_len=0, softmax_smoothing=1.25, length_penalty=0.6, eos_penalty=1.0) # repetition_penalty=1.0, decode_min_len=0, temperature=1.0 used only for llm print(prof.key_averages().table(sort_by="cuda_time_total")) 
prof.export_chrome_trace(f"firered_asr_profile_{batch}_{ATTENTION_BACKEND}.json") else: - hyps = model.model.transcribe(feats, lengths) + hyps = model.model.transcribe(feats, lengths, beam_size=3, nbest=1, decode_max_len=0, softmax_smoothing=1.25, length_penalty=0.6, eos_penalty=1.0) # repetition_penalty=1.0, decode_min_len=0, temperature=1.0 used only for llm total_time += time.time() - start elapsed = time.time() - start @@ -92,8 +92,7 @@ def run(model, batch_wav_path, warmpup=2, trials=10, enable_profile=False, offse avg_latency = total_time / trials rps = batch / avg_latency - #Only print last result for debug purpose - #print("results[0]: ", results[0]) + # Only print last result for debug purpose print("Only print last run results for debug purpose...") for res in results[-batch:]: print(res) @@ -131,7 +130,6 @@ def benchmark(model, audio_dir, batch, warmpup=2, trials=10, enable_profile=Fals benchmark_results = [] e2e_start = time.time() for start in range(0, dataset_size - dataset_size % batch, batch): - #batch_wav_path, dur = file_durations[start:start+batch] batch_wav_path = [path for path, _ in file_durations[start:start + batch]] print(f"Processing {batch} batched data from index {start} to {start + batch-1}") rps, avg_latency, avg_rtf, avg_audio_dur_per_sample, model_results = run(model, batch_wav_path, warmpup, trials, enable_profile, offset=start) @@ -140,11 +138,11 @@ def benchmark(model, audio_dir, batch, warmpup=2, trials=10, enable_profile=Fals # Process remaining data if any remainder = dataset_size % batch if remainder: - #last_batch_wav_path = file_durations[-remainder:] + start+=batch last_batch_wav_path = [path for path, _ in file_durations[-remainder:]] print(f"Processing {remainder} remaining data : {last_batch_wav_path}") rps, avg_latency, avg_rtf, avg_audio_dur_per_sample, model_results = run(model, last_batch_wav_path, warmpup, trials, enable_profile, offset=start) - benchmark_results.append((batch, avg_audio_dur_per_sample, avg_latency, rps, avg_rtf, model_results)) + benchmark_results.append((remainder, avg_audio_dur_per_sample, avg_latency, rps, avg_rtf, model_results)) e2e_duration = time.time() - e2e_start return benchmark_results, e2e_duration @@ -165,17 +163,25 @@ def benchmark(model, audio_dir, batch, warmpup=2, trials=10, enable_profile=Fals model = load_model(model_path) if enable_profile: - #rps, avg_rtf, avg_latency = benchmark(model, audio_dir, batch=1, enable_profile=True) benchmark_results, e2e_duration = benchmark(model, audio_dir, batch=1, enable_profile=enable_profile) else: for batch in batch_sizes: print(f"*************************** batch size {batch} ***************************") - #rps, avg_latency, avg_rtf = benchmark(model, audio_dir, batch=batch) benchmark_results, e2e_duration = benchmark(model, audio_dir, batch=batch, enable_profile=enable_profile) - #print(f"batch size: {batch}, average latency: {avg_latency:.3f}s | RPS: {rps:.2f}, avg RTF: {avg_rtf:.3f}") print(f"\nbatch size: {batch}, e2e latency: {e2e_duration} s") + save_results = [] + save_path = f"ATTENTION_BACKEND_{ATTENTION_BACKEND}_bs_{batch}_output.json" for res in benchmark_results: print(res[5]) + save_results+=res[5] for res in benchmark_results: print(f"batch size: {res[0]}, avg audio duration per sample: {res[1]:.3f} s, avg inference latency {res[2]:.3f} s | RPS: {res[3]:.2f}, avg RTF: {res[4]:.3f}") + with open(save_path, "w", encoding="utf-8") as final: + json.dump(save_results, + final, + indent=2, + ensure_ascii=False, # Keep non-ASCII characters intact + 
default=lambda x: list(x) if isinstance(x, tuple) else str(x)
+        )
+    print(f"Performance results written to {save_path}")

From f63eeb42738abfa14f1ffd0149b55ab4a56a7777 Mon Sep 17 00:00:00 2001
From: Xiake Sun
Date: Wed, 12 Nov 2025 15:50:38 +0000
Subject: [PATCH 11/13] Bugfix for xFormers attention backend

---
 .../models/module/transformer_decoder.py     | 38 +++++++++----------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py
index a81b0f6..1629e8c 100755
--- a/fireredasr/models/module/transformer_decoder.py
+++ b/fireredasr/models/module/transformer_decoder.py
@@ -190,10 +190,10 @@ class DecoderLayer(nn.Module):
     def __init__(self, d_model, n_head, dropout):
         super().__init__()
         self.self_attn_norm = nn.LayerNorm(d_model)
-        self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
+        self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout, attention_type="self_attention")
         self.cross_attn_norm = nn.LayerNorm(d_model)
-        self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
+        self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout, attention_type="cross_attention")
         self.mlp_norm = nn.LayerNorm(d_model)
         self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout)
 
@@ -338,7 +338,7 @@ def set_self_attn_bias(self):
         else:
             print("Unknown attention type used, only support `self_attention`")
 
-    def set_cross_attn_bias(self, mask, bs, q_len, k_len):
+    def set_cross_attn_bias(self, mask, bs, q_len, k_len, n_head, dtype, device):
         if self.attention_type == "cross_attention":
             mask = mask.to(torch.bool)
 
@@ -347,22 +347,22 @@ def set_cross_attn_bias(self, mask, bs, q_len, k_len):
                 mask = mask.expand(bs, 1, q_len, k_len)
 
             # Expand mask for all heads
-            mask = mask.expand(bs, self.n_head, q_len, k_len) \
-                       .reshape(bs * self.n_head, q_len, k_len)
+            mask = mask.expand(bs, n_head, q_len, k_len) \
+                       .reshape(bs * n_head, q_len, k_len)
 
            # Alignment requirement for xformers: pad allocation to multiple of 8
             pad_k = ((k_len + 7) // 8) * 8
             pad_q = ((q_len + 7) // 8) * 8
 
-            bias_full = torch.zeros(bs * self.n_head, pad_q, pad_k,
-                                    dtype=q.dtype, device=q.device)
+            bias_full = torch.zeros(bs * n_head, pad_q, pad_k,
+                                    dtype=dtype, device=device)
             bias_full[:, :q_len, :k_len].masked_fill_(~mask, float("-inf"))
 
             # Slice down to actual shape but keep aligned backing storage
             self.attn_bias = bias_full[:, :q_len, :k_len]
-
-        print("Unknown attention type used, only support `self_attention` and `cross_attention`")
+        else:
+            print("Unknown attention type used, only support `cross_attention`")
 
     def get_attn_bias(self):
         return self.attn_bias
@@ -391,35 +391,35 @@ def forward(self, q, k, v, mask=None):
         k_len = k.size(2) # seq_len_k
         dtype = q.dtype
 
-        q = q.reshape(bs * self.n_head, -1, self.d_k).to(torch.bfloat16)
-        k = k.reshape(bs * self.n_head, -1, self.d_k).to(torch.bfloat16)
-        v = v.reshape(bs * self.n_head, -1, self.d_k).to(torch.bfloat16)
+        q = q.reshape(bs * self.n_head, -1, self.d_k).to(torch.float16)
+        k = k.reshape(bs * self.n_head, -1, self.d_k).to(torch.float16)
+        v = v.reshape(bs * self.n_head, -1, self.d_k).to(torch.float16)
 
         output = None
         if bs == 1:
            output = xops.memory_efficient_attention(q, k, v)
        else:
            attn_bias = None
-            # --- AUTO-DETECT causal self-attention ---
-            # q and k are the same tensor object in memory when this is pure self-attn
+            # --- causal self-attention ---
+            # q and k has same length, pass attn_bias=None
            if self.attention_metadata.attention_type == "self_attention":
-                attn_bias = xops.LowerTriangularMask()
+                attn_bias = None
-
            # --- Cross-attention / padding mask ---
-            #elif mask is not None:
            elif self.attention_metadata.attention_type == "cross_attention" and mask is not None:
                if self.attention_metadata.get_attn_bias() == None:
-                    self.attention_metadata.set_cross_attn_bias(mask, bs, q_len, k_len)
+                    self.attention_metadata.set_cross_attn_bias(mask, bs, q_len, k_len, self.n_head, q.dtype, q.device)
                attn_bias = self.attention_metadata.get_attn_bias()
+            else:
+                print("Unknown attention type used, only support `self_attention` and `cross_attention`")
 
            # --- Run memory-efficient attention ---
            output = xops.memory_efficient_attention(q, k, v,
-                attn_bias=attn_bias if attn_bias is not None else None)
+                attn_bias=attn_bias)
 
        # reshape back to (bs, seq_len, d_model)
        return output.view_as(original_query).to(dtype)
-
 
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()

From 56a739760faefdbc7f2d4fbf0e3fd4a9a4a043bb Mon Sep 17 00:00:00 2001
From: Xiake Sun
Date: Mon, 17 Nov 2025 02:43:51 +0000
Subject: [PATCH 12/13] Reuse cached encoder kv proj for cross-attention in decoding phase

---
 .../models/module/transformer_decoder.py     | 26 ++++++++++++++-----
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py
index 1629e8c..8de3226 100755
--- a/fireredasr/models/module/transformer_decoder.py
+++ b/fireredasr/models/module/transformer_decoder.py
@@ -51,8 +51,10 @@ def batch_beam_search(self, encoder_outputs, src_masks,
                           softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0):
         if ATTENTION_BACKEND.upper() == "XFORMERS":
             for dec_layer in self.layer_stack:
-                dec_layer.self_attn.attention.reset_attn_bias()
                 dec_layer.cross_attn.attention.reset_attn_bias()
+
+        for dec_layer in self.layer_stack:
+            dec_layer.cross_attn.clear_states()
         B = beam_size
         N, Ti, H = encoder_outputs.size()
         device = encoder_outputs.device
@@ -255,6 +257,11 @@ def __init__(self, d_model, n_head, dropout=0.1, attention_type=None):
             exit(1)
         self.fc = nn.Linear(n_head * self.d_k, d_model)
         self.dropout = nn.Dropout(dropout)
+        self.attention_type = attention_type
+        self.kv_proj = None
+
+    def clear_states(self):
+        self.kv_proj = None
 
     def forward(self, q, k, v, mask=None):
         bs = q.size(0)
@@ -262,6 +269,17 @@ def forward(self, q, k, v, mask=None):
         q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k)
         k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
         v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
+        if self.attention_type=="cross_attention":
+            # cross attention reuse the same k,v projection throughout decoding phase
+            if self.kv_proj is None:
+                self.kv_proj = (
+                    self.w_ks(k).view(bs, -1, self.n_head, self.d_k),
+                    self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
+                )
+            k,v = self.kv_proj
+        else:
+            k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
+            v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -332,12 +350,6 @@ def __init__(self, attention_type):
         self.attention_type = attention_type
         self.attn_bias = None
 
-    def set_self_attn_bias(self):
-        if self.attention_type == "self_attention":
-            self.attn_bias = xops.LowerTriangularMask()
-        else:
-            print("Unknown attention type used, only support `self_attention`")
-
     def set_cross_attn_bias(self, mask, bs, q_len, k_len, n_head, dtype, device):
         if self.attention_type == "cross_attention":
             mask = mask.to(torch.bool)
From cb0a06adb83055e2a606a9f5c795627204d5f43c Mon Sep 17 00:00:00 2001
From: Xiake Sun
Date: Mon, 17 Nov 2025 10:06:36 +0000
Subject: [PATCH 13/13] Simplify xFormers backend, remove unused code

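For reference, the only attention bias the simplified backend still builds is the cross-attention padding mask: `set_cross_attn_bias` turns the source padding mask into an additive bias, over-allocates the query/key dimensions to a multiple of 8 (the alignment xFormers expects for half-precision biases), then slices back so the logical shape is unchanged while the backing storage stays aligned. The sketch below is illustrative only and not part of this patch; the function name, shapes, and the CPU/float32 demo values are assumptions for the example (the real path runs in fp16 on GPU).

```python
import torch


def build_cross_attn_bias(key_mask: torch.Tensor, n_head: int, q_len: int,
                          dtype: torch.dtype = torch.float32,
                          device: str = "cpu") -> torch.Tensor:
    """key_mask: (bs, 1, k_len) boolean, True = real encoder frame, False = padding.
    Returns an additive bias of logical shape (bs * n_head, q_len, k_len) whose
    backing storage is padded to a multiple of 8 along both attention dims."""
    bs, _, k_len = key_mask.shape
    mask = key_mask.to(torch.bool).expand(bs, q_len, k_len)          # broadcast over queries
    mask = mask.unsqueeze(1).expand(bs, n_head, q_len, k_len)        # broadcast over heads
    mask = mask.reshape(bs * n_head, q_len, k_len)

    pad_q = ((q_len + 7) // 8) * 8                                   # round up to multiples of 8
    pad_k = ((k_len + 7) // 8) * 8
    bias_full = torch.zeros(bs * n_head, pad_q, pad_k, dtype=dtype, device=device)
    bias_full[:, :q_len, :k_len].masked_fill_(~mask, float("-inf"))  # block padded keys
    return bias_full[:, :q_len, :k_len]                              # slice keeps aligned storage


if __name__ == "__main__":
    lengths = torch.tensor([13, 9])                                  # valid encoder frames per utterance
    key_mask = torch.arange(13)[None, None, :] < lengths[:, None, None]  # (bs=2, 1, k_len=13)
    bias = build_cross_attn_bias(key_mask, n_head=4, q_len=3)
    print(bias.shape)     # torch.Size([8, 3, 13])
    print(bias.stride())  # (128, 16, 1): row stride padded to 16, a multiple of 8
```

The self-attention branch of this backend passes `attn_bias=None` during decoding, so only the cross-attention path needs this construction.
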
---
 .../models/module/transformer_decoder.py     | 44 +++++++------------
 1 file changed, 16 insertions(+), 28 deletions(-)

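The other decoding-time saving kept by this series comes from the previous patch ("Reuse cached encoder kv proj for cross-attention in decoding phase"): the encoder output does not change across decoding steps, so its key/value projections are computed once per utterance, cached in `kv_proj`, and cleared via `clear_states()` before the next beam search. A minimal, self-contained sketch of that idea follows; the class and parameter names are simplified for illustration and are not the repository code.

```python
import torch
import torch.nn as nn


class CachedCrossAttentionKV(nn.Module):
    """Sketch of the caching idea: project the (fixed) encoder output to
    keys/values once per utterance and reuse them at every decoding step."""

    def __init__(self, d_model: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_model // n_head
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.kv_proj = None  # cached (k, v) for the current utterance

    def clear_states(self):
        # Call before decoding a new utterance/batch, mirroring clear_states()
        self.kv_proj = None

    def forward(self, enc_output: torch.Tensor):
        if self.kv_proj is None:
            bs = enc_output.size(0)
            k = self.w_k(enc_output).view(bs, -1, self.n_head, self.d_k)
            v = self.w_v(enc_output).view(bs, -1, self.n_head, self.d_k)
            self.kv_proj = (k, v)
        return self.kv_proj


if __name__ == "__main__":
    enc = torch.randn(2, 50, 256)   # (batch, encoder frames, d_model)
    cache = CachedCrossAttentionKV(256, 4)
    k1, v1 = cache(enc)             # first decoding step: projections computed
    k2, v2 = cache(enc)             # later steps: cached tensors are reused
    assert k1 is k2 and v1 is v2
    cache.clear_states()            # reset before the next utterance
```

The saving grows with the number of emitted tokens, since every decoding step of every beam would otherwise re-run the two projections over the full encoder sequence.
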
diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py
index 8de3226..896021f 100755
--- a/fireredasr/models/module/transformer_decoder.py
+++ b/fireredasr/models/module/transformer_decoder.py
@@ -1,11 +1,10 @@
+import os
 from typing import List, Optional, Dict
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-import math
-import os
 
 try:
     import xformers.ops as xops
@@ -192,10 +191,10 @@ class DecoderLayer(nn.Module):
     def __init__(self, d_model, n_head, dropout):
         super().__init__()
         self.self_attn_norm = nn.LayerNorm(d_model)
-        self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout, attention_type="self_attention")
+        self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout, is_cross=False)
         self.cross_attn_norm = nn.LayerNorm(d_model)
-        self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout, attention_type="cross_attention")
+        self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout, is_cross=True)
         self.mlp_norm = nn.LayerNorm(d_model)
         self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout)
 
@@ -230,7 +229,7 @@ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
 
 
 class DecoderMultiHeadAttention(nn.Module):
-    def __init__(self, d_model, n_head, dropout=0.1, attention_type=None):
+    def __init__(self, d_model, n_head, dropout=0.1, is_cross = False):
         super().__init__()
         self.d_model = d_model
         self.n_head = n_head
@@ -251,25 +250,23 @@ def __init__(self, d_model, n_head, dropout=0.1, attention_type=None):
             if not xformers_available:
                 print("ATTENTION_BACKEND='XFORMERS' selected, but the xformers package is not available. Please install xformers")
                 exit(1)
-            self.attention = DecoderXFormersAttention(self.n_head, self.d_k, self.d_model, temperature=self.d_k ** 0.5, attention_type=attention_type)
+            self.attention = DecoderXFormersAttention(self.n_head, self.d_k, self.d_model, temperature=self.d_k ** 0.5, is_cross=is_cross)
         else:
             print("Unsupported attention backend: ", ATTENTION_BACKEND)
             exit(1)
         self.fc = nn.Linear(n_head * self.d_k, d_model)
         self.dropout = nn.Dropout(dropout)
-        self.attention_type = attention_type
+        self.is_cross = is_cross
         self.kv_proj = None
 
     def clear_states(self):
         self.kv_proj = None
 
-    def forward(self, q, k, v, mask=None):
+    def forward(self, q, k, v, mask=None, cross_kv_cache=None):
         bs = q.size(0)
         q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k)
-        k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
-        v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
-        if self.attention_type=="cross_attention":
+        if self.is_cross:
             # cross attention reuse the same k,v projection throughout decoding phase
             if self.kv_proj is None:
                 self.kv_proj = (
@@ -343,15 +340,14 @@ def forward(self, q, k, v, mask=None):
 
         return output
 
-
 class XFormersAttentionMetadata:
     """Metadata for XFormers Attention backend """
-    def __init__(self, attention_type):
-        self.attention_type = attention_type
+    def __init__(self, is_cross):
+        self.is_cross = is_cross
         self.attn_bias = None
 
     def set_cross_attn_bias(self, mask, bs, q_len, k_len, n_head, dtype, device):
-        if self.attention_type == "cross_attention":
+        if self.is_cross:
             mask = mask.to(torch.bool)
 
             # If mask only has 1 in q_len dimension, expand it
@@ -373,8 +369,6 @@ def set_cross_attn_bias(self, mask, bs, q_len, k_len, n_head, dtype, device):
 
             # Slice down to actual shape but keep aligned backing storage
             self.attn_bias = bias_full[:, :q_len, :k_len]
-
-        else:
-            print("Unknown attention type used, only support `cross_attention`")
 
     def get_attn_bias(self):
         return self.attn_bias
@@ -384,13 +378,14 @@ def reset_attn_bias(self):
 
 # xFormers Attention
 class DecoderXFormersAttention(nn.Module):
-    def __init__(self, n_head, d_k, d_model, temperature, attention_type):
+    def __init__(self, n_head, d_k, d_model, temperature, is_cross):
         super().__init__()
         self.temperature = temperature
         self.n_head = n_head
         self.d_k = d_k
         self.d_model = d_model
-        self.attention_metadata = XFormersAttentionMetadata(attention_type)
+        self.attention_metadata = XFormersAttentionMetadata(is_cross)
+        self.is_cross = is_cross
 
     def reset_attn_bias(self):
         self.attention_metadata.reset_attn_bias()
@@ -412,25 +407,18 @@ def forward(self, q, k, v, mask=None):
             output = xops.memory_efficient_attention(q, k, v)
         else:
             attn_bias = None
-            # --- causal self-attention ---
-            # q and k has same length, pass attn_bias=None
-            if self.attention_metadata.attention_type == "self_attention":
-                attn_bias = None
-            # --- Cross-attention / padding mask ---
-            elif self.attention_metadata.attention_type == "cross_attention" and mask is not None:
+            if self.is_cross and mask is not None:
                 if self.attention_metadata.get_attn_bias() == None:
                     self.attention_metadata.set_cross_attn_bias(mask, bs, q_len, k_len, self.n_head, q.dtype, q.device)
                 attn_bias = self.attention_metadata.get_attn_bias()
-            else:
-                print("Unknown attention type used, only support `self_attention` and `cross_attention`")
 
             # --- Run memory-efficient attention ---
             output = xops.memory_efficient_attention(q, k, v,
                 attn_bias=attn_bias)
 
         # reshape back to (bs, seq_len, d_model)
-        return output.view_as(original_query).to(dtype)
+        return output.to(dtype)
 
 class PositionwiseFeedForward(nn.Module):
     def __init__(self, d_model, d_ff, dropout=0.1):