diff --git a/fireredasr/models/fireredasr_aed.py b/fireredasr/models/fireredasr_aed.py index 4938c2c..522a46d 100644 --- a/fireredasr/models/fireredasr_aed.py +++ b/fireredasr/models/fireredasr_aed.py @@ -4,6 +4,7 @@ from fireredasr.models.module.transformer_decoder import TransformerDecoder +@torch.compile(mode="max-autotune") class FireRedAsrAed(torch.nn.Module): @classmethod def from_args(cls, args): @@ -28,6 +29,7 @@ def transcribe(self, padded_input, input_lengths, beam_size=1, nbest=1, decode_max_len=0, softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0): enc_outputs, _, enc_mask = self.encoder(padded_input, input_lengths) + self.decoder.clear() nbest_hyps = self.decoder.batch_beam_search( enc_outputs, enc_mask, beam_size, nbest, decode_max_len, diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py index 2088b08..2b083c0 100644 --- a/fireredasr/models/module/transformer_decoder.py +++ b/fireredasr/models/module/transformer_decoder.py @@ -1,10 +1,49 @@ from typing import List, Optional, Dict +import inspect import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor - +import math +import os +import torch +import torch.nn as nn +import xformers.ops as xops + +from flash_attn import flash_attn_func, flash_attn_varlen_func + +import einops + +ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "SDPA") # Option: "NATIVE", "SDPA", "XFORMERS", "FLASH_ATTN" +MultiHeadAttention = None +print("ATTENTION_BACKEND: ", ATTENTION_BACKEND) + + +class AttentionMeta(object): + + def __init__(self): + self.seq_lens = None + self.cu_seqlens_q = None + self.cu_seqlens_k = None + self.max_seqlen_q = None + self.max_seqlen_k = None + self.total_seqlen_k = None + self.active_indices = None + + def update(self, seq_lens=None, + cu_seqlens_q=None, cu_seqlens_k=None, + max_seqlen_q=None, max_seqlen_k=None, + total_seqlen_k=None, + active_indices=None): + self.seq_lens = seq_lens + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_k = cu_seqlens_k + self.max_seqlen_q = max_seqlen_q + self.max_seqlen_k = max_seqlen_k + self.total_seqlen_k = total_seqlen_k + self.active_indices = active_indices + class TransformerDecoder(nn.Module): def __init__( @@ -22,7 +61,6 @@ def __init__( # Components self.tgt_word_emb = nn.Embedding(odim, d_model, padding_idx=self.pad_id) self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen) - self.dropout = nn.Dropout(residual_dropout) self.layer_stack = nn.ModuleList() for l in range(n_layers): @@ -34,85 +72,152 @@ def __init__( self.tgt_word_prj.weight = self.tgt_word_emb.weight self.scale = (d_model ** 0.5) + + self.scores_map = {} + self.stride_map = {} + self.filter_indexes = {} + self.active_masks = {} + + + def clear(self): + for dec_layer in self.layer_stack: + dec_layer.clear() + + + def cal_seq_lens(self, mask): + mask = mask.squeeze(1) + return mask.sum(dim=1, dtype=torch.int32) + def batch_beam_search(self, encoder_outputs, src_masks, beam_size=1, nbest=1, decode_max_len=0, softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0): B = beam_size N, Ti, H = encoder_outputs.size() + M = N * B device = encoder_outputs.device maxlen = decode_max_len if decode_max_len > 0 else Ti assert eos_penalty > 0.0 and eos_penalty <= 1.0 # Init - encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, Ti, H) - src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, -1, Ti) - ys = torch.ones(N*B, 1).fill_(self.sos_id).long().to(device) + encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(M, Ti, H) + src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(M, -1, Ti) + raw_ys = ys = torch.ones(M, 1, device=device).fill_(self.sos_id).long() caches: List[Optional[Tensor]] = [] for _ in range(self.n_layers): - caches.append(None) - scores = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(device) - scores = scores.repeat(N).view(N*B, 1) + caches.append(torch.empty(M, 0, H, device=device, dtype=encoder_outputs.dtype)) + + if B not in self.scores_map: + scores = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(device) + self.scores_map[B] = scores + scores = self.scores_map[B].repeat(N).view(M, 1) + finished_mask_score = self.scores_map[B] is_finished = torch.zeros_like(scores) + + if (B, N) not in self.stride_map: + stride = B * torch.arange(N).view(N, 1).repeat(1, B).view(M).to(device).long() + filter_index = torch.arange(M, device=device, dtype=torch.int32) + active_mask = torch.ones(M, dtype=torch.bool, device=device) + self.stride_map[(B, N)] = stride + self.filter_indexes[M] = filter_index + self.active_masks[M] = active_mask + stride = self.stride_map[(B, N)] + active_mask = self.active_masks[M] + active_indices = self.filter_indexes[M] + last_t_logit = None + + attn_meta = AttentionMeta() # Autoregressive Prediction for t in range(maxlen): - tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id) - - dec_output = self.dropout( - self.tgt_word_emb(ys) * self.scale + - self.positional_encoding(ys)) - - i = 0 - for dec_layer in self.layer_stack: + + def expand(f, mask, indices, value=0.0, t=None): + if t is None: + t = torch.full((len(mask), *list(f.shape[1:])), value, dtype=f.dtype, device=f.device) + t[indices] = f + return t + + dec_output = self.tgt_word_emb(ys) * self.scale + self.positional_encoding(ys) + dec_output = dec_output[active_indices] + t_encoder_outputs = encoder_outputs[active_indices] + t_src_mask = src_mask[active_indices] + + if ATTENTION_BACKEND in {"FLASH_ATTN"}: + tgt_mask = raw_ys[active_indices] + seq_lens = self.cal_seq_lens(t_src_mask) + seq_lens_cpu = seq_lens.cpu() + total_seqlen_k, max_seqlen_k = seq_lens_cpu.sum().item(), seq_lens_cpu.max().item() + attn_meta.update(seq_lens=seq_lens, + max_seqlen_k=max_seqlen_k, + total_seqlen_k=total_seqlen_k, + active_indices=active_indices) + else: + tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id) + tgt_mask = tgt_mask[active_indices] + + for i, dec_layer in enumerate(self.layer_stack): dec_output = dec_layer.forward( - dec_output, encoder_outputs, - tgt_mask, src_mask, - cache=caches[i]) - caches[i] = dec_output - i += 1 + dec_output, + t_encoder_outputs, + tgt_mask, + t_src_mask, + cache=caches[i][active_indices], + attn_meta=attn_meta) + caches[i] = dec_output dec_output = self.layer_norm_out(dec_output) t_logit = self.tgt_word_prj(dec_output[:, -1]) + if last_t_logit is None: + last_t_logit = t_logit + else: + last_t_logit[active_indices] = t_logit + t_logit = last_t_logit t_scores = F.log_softmax(t_logit / softmax_smoothing, dim=-1) - + if eos_penalty != 1.0: t_scores[:, self.eos_id] *= eos_penalty t_topB_scores, t_topB_ys = torch.topk(t_scores, k=B, dim=1) - t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished) + t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished, mask_score=finished_mask_score) t_topB_ys = self.set_finished_beam_y_to_eos(t_topB_ys, is_finished) # Accumulated scores = scores + t_topB_scores - # Pruning + # Pruning scores = scores.view(N, B*B) scores, topB_score_ids = torch.topk(scores, k=B, dim=1) scores = scores.view(-1, 1) - - topB_row_number_in_each_B_rows_of_ys = torch.div(topB_score_ids, B).view(N*B) - stride = B * torch.arange(N).view(N, 1).repeat(1, B).view(N*B).to(device) - topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long() + + topB_row_number_in_each_B_rows_of_ys = torch.div(topB_score_ids, B).view(M) + topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride # Update ys ys = ys[topB_row_number_in_ys] - t_ys = torch.gather(t_topB_ys.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1) + t_ys = torch.gather(t_topB_ys.view(N, B*B), dim=1, index=topB_score_ids).view(M, 1) ys = torch.cat((ys, t_ys), dim=1) # Update caches new_caches: List[Optional[Tensor]] = [] + target = torch.full((len(active_mask), *list(caches[0].shape[1:])), + 0.0, + dtype=caches[0].dtype, + device=caches[0].device) for cache in caches: - if cache is not None: - new_caches.append(cache[topB_row_number_in_ys]) + new_caches.append(expand(cache, active_mask, active_indices,t=target)[topB_row_number_in_ys]) caches = new_caches # Update finished state is_finished = t_ys.eq(self.eos_id) - if is_finished.sum().item() == N*B: - break + is_finished_n = is_finished.sum().item() + active_mask = ~is_finished.squeeze() + #active_indices = self.filter_indexes[M][active_mask] + active_indices = torch.nonzero_static(active_mask, size=M - int(is_finished_n)).squeeze(1) + if is_finished_n == M: + break + # Length penalty (follow GNMT) scores = scores.view(N, B) ys = ys.view(N, B, -1) @@ -123,28 +228,22 @@ def batch_beam_search(self, encoder_outputs, src_masks, nbest_scores, nbest_ids = torch.topk(scores, k=int(nbest), dim=1) nbest_scores = -1.0 * nbest_scores index = nbest_ids + B * torch.arange(N).view(N, 1).to(device).long() - nbest_ys = ys.view(N*B, -1)[index.view(-1)] - nbest_ys = nbest_ys.view(N, nbest_ids.size(1), -1) - nbest_ys_lengths = ys_lengths.view(N*B)[index.view(-1)].view(N, -1) - - # result - nbest_hyps: List[List[Dict[str, Tensor]]] = [] - for n in range(N): - n_nbest_hyps: List[Dict[str, Tensor]] = [] - for i, score in enumerate(nbest_scores[n]): - new_hyp = { - "yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]] - } - n_nbest_hyps.append(new_hyp) - nbest_hyps.append(n_nbest_hyps) - return nbest_hyps + nbest_ys = ys.view(M, -1)[index.view(-1)].view(N, nbest_ids.size(1), -1) + nbest_ys_lengths = ys_lengths.view(M)[index.view(-1)].view(N, -1) + + return [ + [ + {"yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]]} + for i, _ in enumerate(nbest_scores[n]) + ] + for n in range(N) + ] def ignored_target_position_is_0(self, padded_targets, ignore_id): mask = torch.ne(padded_targets, ignore_id) mask = mask.unsqueeze(dim=1) T = padded_targets.size(-1) - upper_tri_0_mask = self.upper_triangular_is_0(T).unsqueeze(0).to(mask.dtype) - upper_tri_0_mask = upper_tri_0_mask.to(mask.dtype).to(mask.device) + upper_tri_0_mask = self.upper_triangular_is_0(T).unsqueeze(0).to(mask.dtype).to(mask.device) return mask.to(torch.uint8) & upper_tri_0_mask.to(torch.uint8) def upper_triangular_is_0(self, size): @@ -152,10 +251,9 @@ def upper_triangular_is_0(self, size): tri_left_ones = torch.tril(ones) return tri_left_ones.to(torch.uint8) - def set_finished_beam_score_to_zero(self, scores, is_finished): + def set_finished_beam_score_to_zero(self, scores, is_finished, mask_score): NB, B = scores.size() is_finished = is_finished.float() - mask_score = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(scores.device) mask_score = mask_score.view(1, B).repeat(NB, 1) return scores * (1 - is_finished) + mask_score * is_finished @@ -169,49 +267,60 @@ def get_ys_lengths(self, ys): return ys_lengths.int() - class DecoderLayer(nn.Module): def __init__(self, d_model, n_head, dropout): super().__init__() self.self_attn_norm = nn.LayerNorm(d_model) - self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout) + if ATTENTION_BACKEND.upper() == "NATIVE": + MultiHeadAttention = DecoderMultiHeadAttention + elif ATTENTION_BACKEND.upper() == "SDPA": + MultiHeadAttention = DecoderMHATorchSDPA + elif ATTENTION_BACKEND.upper() == "XFORMERS": + MultiHeadAttention = DecoderXFormersAttention + elif ATTENTION_BACKEND.upper() == "FLASH_ATTN": + MultiHeadAttention = DecoderMHAFlashAttn + else: + print("Unsupported attention backend: ", ATTENTION_BACKEND) + exit(1) + self.self_attn = MultiHeadAttention(d_model, n_head, dropout) self.cross_attn_norm = nn.LayerNorm(d_model) - self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout) + self.cross_attn = MultiHeadAttention(d_model, n_head, dropout) self.mlp_norm = nn.LayerNorm(d_model) self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout) + + def clear(self): + self.cross_attn.clear() def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask, - cache=None): + cache=None, attn_meta:AttentionMeta=None): x = dec_input residual = x x = self.self_attn_norm(x) - if cache is not None: + if cache.shape[1]: xq = x[:, -1:, :] residual = residual[:, -1:, :] - self_attn_mask = self_attn_mask[:, -1:, :] + self_attn_mask = self.self_attn.parse_mask(self_attn_mask) else: xq = x - x = self.self_attn(xq, x, x, mask=self_attn_mask) - x = residual + x - + x = residual + self.self_attn(xq, x, x, mask=self_attn_mask) residual = x - x = self.cross_attn_norm(x) - x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask) - x = residual + x + x = residual+self.cross_attn(self.cross_attn_norm(x), + enc_output, + enc_output, + mask=cross_attn_mask, + is_cross=True, + attn_meta=attn_meta) residual = x x = self.mlp_norm(x) x = residual + self.mlp(x) - if cache is not None: - x = torch.cat([cache, x], dim=1) - + x = torch.cat([cache, x], dim=1) return x - - -class DecoderMultiHeadAttention(nn.Module): + +class BaseMultiHeadAttention(nn.Module): def __init__(self, d_model, n_head, dropout=0.1): super().__init__() self.d_model = d_model @@ -221,13 +330,24 @@ def __init__(self, d_model, n_head, dropout=0.1): self.w_qs = nn.Linear(d_model, n_head * self.d_k) self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) self.w_vs = nn.Linear(d_model, n_head * self.d_k) - - self.attention = DecoderScaledDotProductAttention( - temperature=self.d_k ** 0.5) self.fc = nn.Linear(n_head * self.d_k, d_model) - self.dropout = nn.Dropout(dropout) - def forward(self, q, k, v, mask=None): + def forward(self, q, k, v, mask=None, is_cross=False, attn_meta:AttentionMeta=None): + raise NotImplementedError + + def clear(self): + pass + + def parse_mask(self, mask): + return mask[:, -1:, :] + +# Native MHA +class DecoderMultiHeadAttention(BaseMultiHeadAttention): + def __init__(self, d_model, n_head, dropout=0.1): + super().__init__(d_model, n_head, dropout) + self.attention = DecoderScaledDotProductAttention(temperature=self.d_k ** 0.5) + + def forward(self, q, k, v, mask=None, is_cross=False, attn_meta:AttentionMeta=None): bs = q.size(0) q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k) @@ -244,11 +364,9 @@ def forward(self, q, k, v, mask=None): output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model) output = self.fc(output) - output = self.dropout(output) return output - class DecoderScaledDotProductAttention(nn.Module): def __init__(self, temperature): super().__init__() @@ -266,21 +384,209 @@ def forward(self, q, k, v, mask=None): output = torch.matmul(attn, v) return output +# MHA with Torch SDPA +class DecoderMHATorchSDPA(BaseMultiHeadAttention): + def __init__(self, d_model, n_head, dropout=0.1): + super().__init__(d_model, n_head, dropout) + self.attention = DecoderTorchSDPA(temperature=self.d_k ** 0.5) + + def forward(self, q, k, v, mask=None, is_cross=False, attn_meta:AttentionMeta=None): + bs = q.size(0) + + q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) + k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) + v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) + + output = self.attention(q, k, v, mask=mask.unsqueeze(1)) + output = self.fc(output.transpose(1, 2).contiguous().view(bs, -1, self.d_model)) + return output + +class DecoderTorchSDPA(nn.Module): + def __init__(self, temperature): + super().__init__() + self.temperature = temperature + + def forward(self, q, k, v, mask=None): + """ + q, k, v: (batch, num_heads, seq_len, d_k) + mask: optional attention mask + - If boolean: shape (batch, 1, seq_len, seq_len) or broadcastable. + True means 'mask out'. + - If float: same shape, with -inf for masked positions. + """ + output = F.scaled_dot_product_attention( + q, k, v, + attn_mask=mask.eq(1), + dropout_p=0.0, # set >0 only during training + is_causal=False, # set True to get causal masking automatically + ) + return output + + +# xFormers Attention +class DecoderXFormersAttention(BaseMultiHeadAttention): + def __init__(self, d_model, n_head, dropout=0.1): + super().__init__(d_model, n_head, dropout) + self.attn_bias = None + + def set_cross_attn_bias(self, mask, bs, q_len, k_len, n_head, dtype, device, is_cross=False): + if is_cross: + mask = mask.to(torch.bool) + + # If mask only has 1 in q_len dimension, expand it + if mask.size(2) == 1 and q_len > 1: + mask = mask.expand(bs, 1, q_len, k_len) + + # Expand mask for all heads + mask = mask.expand(bs, n_head, q_len, k_len) \ + .reshape(bs * n_head, q_len, k_len) + + # Alignment requirement for xformers: pad allocation to multiple of 8 + pad_k = ((k_len + 7) // 8) * 8 + pad_q = ((q_len + 7) // 8) * 8 + + bias_full = torch.zeros(bs * n_head, pad_q, pad_k, + dtype=dtype, device=device) + + bias_full[:, :q_len, :k_len].masked_fill_(~mask, float("-inf")) + + # Slice down to actual shape but keep aligned backing storage + self.attn_bias = bias_full[:, :q_len, :k_len] + else: + print("Unknown attention type used, only support `cross_attention`") + + def get_attn_bias(self): + return self.attn_bias + + def reset_attn_bias(self): + self.attn_bias = None + + def forward(self, q, k, v, mask=None, is_cross=False, attn_meta:AttentionMeta=None): + bs = q.size(0) + + q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) + k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) + v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k).transpose(1, 2) + + if mask is not None: + mask = mask.unsqueeze(1) + + original_query = q + + # Save lengths + q_len = q.size(2) # seq_len_q + k_len = k.size(2) # seq_len_k + dtype = q.dtype + + q = q.reshape(bs * self.n_head, -1, self.d_k).to(torch.float16) + k = k.reshape(bs * self.n_head, -1, self.d_k).to(torch.float16) + v = v.reshape(bs * self.n_head, -1, self.d_k).to(torch.float16) + + output = None + if bs == 1: + output = xops.memory_efficient_attention(q, k, v) + else: + attn_bias = None + # --- causal self-attention --- + # q and k has same length, pass attn_bias=None + if not is_cross: + attn_bias = None + + # --- Cross-attention / padding mask --- + elif is_cross and mask is not None: + self.set_cross_attn_bias(mask, bs, q_len, k_len, self.n_head, q.dtype, q.device, is_cross=is_cross) + attn_bias = self.get_attn_bias() + else: + print("Unknown attention type used, only support `self_attention` and `cross_attention`") + + # --- Run memory-efficient attention --- + + output = xops.memory_efficient_attention(q, k, v, + attn_bias=attn_bias) + # reshape back to (bs, seq_len, d_model) + output = output.view_as(original_query).to(dtype) + output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model) + output = self.fc(output) + return output + + def clear(self): + self.reset_attn_bias() + + def parse_mask(self, mask): + return mask + + +class DecoderMHAFlashAttn(BaseMultiHeadAttention): + def __init__(self, d_model, n_head, dropout=0.1): + super().__init__(d_model, n_head, dropout) + + self.cross_cache_k = None + self.cross_cache_v = None + + def clear(self): + self.cross_cache_k = None + self.cross_cache_v = None + + def forward(self, q, k, v, mask=None, is_cross=False, attn_meta:AttentionMeta=None): + is_casual = not is_cross + bs = q.size(0) + + if is_casual: + q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k) + k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k) + v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k) + output = flash_attn_func(q, k, v, causal=is_casual) + output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model) + else: + mask = mask.squeeze(1) + bool_mask = mask.view(-1).bool() + q = self.w_qs(q).view(-1, self.n_head, self.d_k) + + if self.cross_cache_k is None or self.cross_cache_v is None: + k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k) + v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k) + + self.cross_cache_k = k + self.cross_cache_v = v + + active_indices = attn_meta.active_indices + seq_lens = attn_meta.seq_lens + total, _max_seq_k = attn_meta.total_seqlen_k, attn_meta.max_seqlen_k + bool_indices = torch.nonzero_static(bool_mask, size=total).squeeze(1) + var_k = self.cross_cache_k[active_indices].view(-1, self.n_head, self.d_k)[bool_indices] + var_v = self.cross_cache_v[active_indices].view(-1, self.n_head, self.d_k)[bool_indices] + cu_seqlens_q = torch.arange(0, bs + 1, 1, device=seq_lens.device, dtype=torch.int32) + cu_seqlens_k = torch.zeros(bs + 1, device=seq_lens.device, dtype=torch.int32) + cu_seqlens_k[1:] = torch.cumsum(seq_lens, dim=0) + output = flash_attn_varlen_func(q=q, + k=var_k, + v=var_v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=1, + max_seqlen_k=_max_seq_k, + causal=is_casual) + output = output.contiguous().view(bs, -1, self.d_model) + output = self.fc(output) + return output + + def parse_mask(self, mask): + return mask + +# @torch.compile(mode="reduce-overhead", backend="inductor") class PositionwiseFeedForward(nn.Module): def __init__(self, d_model, d_ff, dropout=0.1): super().__init__() self.w_1 = nn.Linear(d_model, d_ff) self.act = nn.GELU() self.w_2 = nn.Linear(d_ff, d_model) - self.dropout = nn.Dropout(dropout) def forward(self, x): output = self.w_2(self.act(self.w_1(x))) - output = self.dropout(output) return output - +# @torch.compile(mode="reduce-overhead", backend="inductor") class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__()