-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Rope asr aed ngpt #13933
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rope asr aed ngpt #13933
Changes from 3 commits
6d287a3
c785c1a
3f7cfaf
5ef8eb9
82f3c98
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,9 +16,11 @@ | |
| from collections import OrderedDict | ||
| from dataclasses import dataclass | ||
|
|
||
| from rotary_embedding_torch import RotaryEmbedding | ||
|
||
| import torch | ||
| import torch.distributed | ||
| import torch.nn as nn | ||
| from rotary_embedding_torch import RotaryEmbedding | ||
|
|
||
| from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling | ||
| from nemo.core.classes.common import typecheck | ||
|
|
@@ -27,6 +29,7 @@ | |
| from nemo.core.classes.module import NeuralModule | ||
| from nemo.core.neural_types import AcousticEncodedRepresentation, LengthsType, NeuralType, SpectrogramType | ||
| from nemo.utils import logging | ||
| INF_VAL = 10000.0 | ||
|
|
||
| try: | ||
| from flash_attn import flash_attn_func | ||
|
|
@@ -174,6 +177,41 @@ def forward_for_export( | |
|
|
||
| def streaming_post_process(self, rets, keep_all_outputs=True): | ||
| raise NotImplementedError() | ||
| def _create_masks(self, att_context_size, padding_length, max_audio_length, offset, device): | ||
|
|
||
| att_mask = torch.ones(1, max_audio_length, max_audio_length, dtype=torch.bool, device=device) | ||
|
|
||
|
|
||
| if att_context_size[0] >= 0: | ||
| att_mask = att_mask.triu(diagonal=-att_context_size[0]) | ||
| if att_context_size[1] >= 0: | ||
| att_mask = att_mask.tril(diagonal=att_context_size[1]) | ||
|
|
||
| # pad_mask is the masking to be used to ignore paddings | ||
| pad_mask = torch.arange(0, max_audio_length, device=device).expand( | ||
| padding_length.size(0), -1 | ||
| ) < padding_length.unsqueeze(-1) | ||
|
|
||
| if offset is not None: | ||
| pad_mask_off = torch.arange(0, max_audio_length, device=device).expand( | ||
| padding_length.size(0), -1 | ||
| ) >= offset.unsqueeze(-1) | ||
| pad_mask = pad_mask_off.logical_and(pad_mask) | ||
|
|
||
| if att_mask is not None: | ||
| # pad_mask_for_att_mask is the mask which helps to ignore paddings | ||
| pad_mask_for_att_mask = pad_mask.unsqueeze(1).repeat([1, max_audio_length, 1]) | ||
| pad_mask_for_att_mask = torch.logical_and(pad_mask_for_att_mask, pad_mask_for_att_mask.transpose(1, 2)) | ||
| # att_mask is the masking to be used by the MHA layers to ignore the tokens not supposed to be visible | ||
| att_mask = att_mask[:, :max_audio_length, :max_audio_length] | ||
| # paddings should also get ignored, so pad_mask_for_att_mask is used to ignore their corresponding scores | ||
| att_mask = torch.logical_and(pad_mask_for_att_mask, att_mask.to(pad_mask_for_att_mask.device)) | ||
| att_mask = ~att_mask | ||
|
|
||
| pad_mask = ~pad_mask | ||
| return pad_mask, att_mask | ||
|
|
||
|
|
||
|
|
||
| @typecheck() | ||
| def forward( | ||
|
|
@@ -197,7 +235,15 @@ def forward_internal( | |
|
|
||
| audio_signal = audio_signal.transpose(1, 2) | ||
| x, length = self.pre_encode(x=audio_signal, lengths=length) | ||
| x = self.ngpt(x) | ||
| padding_length = length | ||
| pad_mask, att_mask = self._create_masks( | ||
| att_context_size=[-1,-1], | ||
| padding_length=padding_length, | ||
| max_audio_length=x.size(1), | ||
| offset=None, | ||
| device=x.device, | ||
| ) | ||
| x = self.ngpt(x, mask=att_mask) | ||
| x = x.transpose(1, 2) | ||
|
|
||
| return x, length | ||
|
|
@@ -208,36 +254,15 @@ def normalize_matrices(self): | |
| self.ngpt.normalize_matrices() | ||
|
|
||
|
|
||
| def apply_rotary_position_embeddings(sinusoidal_pos, q, k): | ||
| # Split the sinusoidal_pos into sin and cos parts | ||
| sin, cos = sinusoidal_pos.chunk(2, dim=-1) | ||
| # Apply the rotary embeddings to the query and key | ||
| q_rot = torch.stack((-q[..., 1::2], q[..., ::2]), dim=-1) | ||
| k_rot = torch.stack((-k[..., 1::2], k[..., ::2]), dim=-1) | ||
| q_rot = torch.reshape(q_rot, q.shape[:-1] + (q.shape[-1] // 2, 2)) * torch.stack((cos, sin), dim=-1) | ||
| k_rot = torch.reshape(k_rot, k.shape[:-1] + (k.shape[-1] // 2, 2)) * torch.stack((cos, sin), dim=-1) | ||
| q_rot = torch.reshape(q_rot, q.shape) | ||
| k_rot = torch.reshape(k_rot, k.shape) | ||
| return q_rot.to(q.dtype), k_rot.to(k.dtype) | ||
|
|
||
|
|
||
| def get_sinusoidal_embeddings(n_positions, dim, device): | ||
| """Generate sinusoidal positional embeddings.""" | ||
| position = torch.arange(n_positions, dtype=torch.float, device=device).unsqueeze(1) | ||
| div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float, device=device) * (-math.log(10000.0) / dim)) | ||
| sinusoidal_emb = torch.empty((n_positions, dim), device=device) | ||
| sinusoidal_emb[:, 0::2] = torch.sin(position * div_term) | ||
| sinusoidal_emb[:, 1::2] = torch.cos(position * div_term) | ||
| return sinusoidal_emb | ||
|
|
||
|
|
||
| def justnorm(x, fp32: bool = False, idim: int = -1): | ||
| def justnorm(x, fp32: bool = False, idim: int = -1,eps: float = 1e-10): | ||
|
||
|
|
||
| if fp32: | ||
| dtype = x.dtype | ||
| x = x.float() | ||
| res = (x / x.norm(p=2, dim=idim, keepdim=True)).to(dtype=dtype) | ||
| else: | ||
| res = x / x.norm(p=2, dim=idim, keepdim=True) | ||
| norm = x.norm(p=2, dim=idim, keepdim=True) | ||
| res = x / (norm + eps) | ||
| return res | ||
|
|
||
|
|
||
|
|
@@ -247,15 +272,16 @@ def justnorm_fp32(x, idim: int = -1): | |
|
|
||
| class Block(nn.Module): | ||
|
|
||
| def __init__(self, config, iblock): | ||
| def __init__(self, config, rotary_emb): | ||
| super().__init__() | ||
| self.config = config | ||
|
|
||
| self.rotary_emb = rotary_emb | ||
| self.key = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) | ||
| self.query = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) | ||
| self.value = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) | ||
| self.att_c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) | ||
|
|
||
| self.rotary_emb = RotaryEmbedding(dim =config.n_embd // config.n_head) | ||
| self.c_fc = nn.Linear(config.n_embd, 2 * 4 * config.n_embd, bias=config.bias) | ||
| self.silu = nn.SiLU() | ||
| self.mlp_c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) | ||
|
|
@@ -301,37 +327,48 @@ def forward(self, h, mask): | |
| q = q.view(B, T, self.config.n_head, self.config.n_embd // self.config.n_head) | ||
| k = k.view(B, T, self.config.n_head, self.config.n_embd // self.config.n_head) | ||
| v = v.view(B, T, self.config.n_head, self.config.n_embd // self.config.n_head) | ||
|
|
||
| sinusoidal_pos = get_sinusoidal_embeddings(T, self.config.n_embd // self.config.n_head, device=q.device) | ||
| q, k = apply_rotary_position_embeddings(sinusoidal_pos, q.transpose(1, 2), k.transpose(1, 2)) | ||
| q = self.rotary_emb.rotate_queries_or_keys(q.transpose(2, 1)) | ||
| k = self.rotary_emb.rotate_queries_or_keys(k.transpose(2, 1)) | ||
| q = q.transpose(2, 1) | ||
| k = k.transpose(2, 1) | ||
| B, T, H, D = q.shape | ||
|
|
||
| # Apply nGPT specific normalization if enabled | ||
| if self.config.use_nGPT == 1: | ||
| sqk = (self.sqk * (self.sqk_init_value / self.sqk_init_scaling)).view( | ||
| 1, 1, self.config.n_head, self.config.n_embd // self.config.n_head | ||
| ) | ||
| q = sqk * justnorm(q) | ||
| k = sqk * justnorm(k) | ||
|
|
||
| # Compute attention scaling factor | ||
| sqrt_head_dim = (self.config.n_embd / self.config.n_head) ** 0.5 | ||
| if self.config.use_nGPT == 0: | ||
| softmax_scale = 1.0 / sqrt_head_dim | ||
| if self.config.use_nGPT == 1: | ||
| softmax_scale = sqrt_head_dim | ||
| y = flash_attn_func( | ||
| q, | ||
| k, | ||
| v, | ||
| dropout_p=0.0, | ||
| softmax_scale=softmax_scale, | ||
| causal=False, | ||
| window_size=(-1, -1), | ||
| alibi_slopes=None, | ||
| deterministic=False, | ||
| ) | ||
| y = y.to(dtype=q.dtype) | ||
| y = y.contiguous().view(B, T, self.config.n_embd) | ||
| softmax_scale = sqrt_head_dim if self.config.use_nGPT == 1 else 1.0 / sqrt_head_dim | ||
|
|
||
| # Reshape tensors for attention computation | ||
| q_ = q.permute(0, 2, 1, 3) # (B, H, T, D) | ||
| k_ = k.permute(0, 2, 1, 3) # (B, H, T, D) | ||
| v_ = v.permute(0, 2, 1, 3) # (B, H, T, D) | ||
|
|
||
| # Compute attention scores | ||
| scores = torch.matmul(q_, k_.transpose(-2, -1)) * softmax_scale # (B, H, T, T) | ||
|
|
||
| # Apply attention mask if provided | ||
| if mask is not None: | ||
| scores = scores.masked_fill(mask.unsqueeze(1), -INF_VAL) | ||
|
|
||
| # ---- (4) softmax & (optional) dropout ---- | ||
| attn = torch.softmax(scores, dim=-1) # (B, H, T, T) | ||
| attn = attn.to(v_.dtype) # ⇐ match v_ (BF16) | ||
|
|
||
| if mask is not None: | ||
| attn = attn.masked_fill(mask.unsqueeze(1), 0.0) | ||
|
|
||
| out = torch.matmul(attn, v_) # (B, H, T, D) | ||
|
|
||
| y = out.permute(0, 2, 1, 3).contiguous().view(B, T, H * D) | ||
|
|
||
|
|
||
| h_att = self.att_c_proj(y) | ||
|
|
||
|
|
@@ -407,11 +444,12 @@ def __init__(self, config): | |
| super().__init__() | ||
| self.config = config | ||
|
|
||
| self.rotary_emb = RotaryEmbedding(dim=config.n_embd // config.n_head) | ||
| self.transformer = nn.ModuleDict( | ||
| dict( | ||
| # wte=nn.Embedding(config.vocab_size, config.n_embd), | ||
| # drop=nn.Dropout(config.dropout), | ||
| h=nn.ModuleList([Block(config, il) for il in range(config.n_layer)]) | ||
| h=nn.ModuleList([Block(config, rotary_emb=self.rotary_emb) for il in range(config.n_layer)]) | ||
| ) | ||
| ) | ||
| # self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why remove torch.no_grad?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missed it. Updated