Merged
Changes from 3 commits
7 changes: 3 additions & 4 deletions nemo/collections/asr/models/rnnt_bpe_models.py
@@ -629,10 +629,9 @@ def optimizer_step(
return ans

def normalize_matrices(self):
with torch.no_grad():
Collaborator: why remove torch.no_grad?

Collaborator (Author): Missed it. Updated

for module in self.modules():
if hasattr(module, "normalize_matrices"):
module.normalize_matrices()
for module in self.children():
if hasattr(module, "normalize_matrices"):
module.normalize_matrices()

def on_train_start(self):
super().on_train_start()
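Given the author's reply that the torch.no_grad() wrapper was dropped by mistake and restored afterward, the reconciled method would presumably combine the wrapper with the new children() loop, roughly as in this sketch (a hypothetical reconstruction; the final commit is not shown in this 3-commit view):

```python
import torch
import torch.nn as nn


class ModelSketch(nn.Module):
    # Hypothetical stand-in for the RNNT model; only normalize_matrices is shown.
    def normalize_matrices(self):
        # Weight re-normalization mutates parameters in place and must not be
        # recorded by autograd, hence the torch.no_grad() the reviewer asked about.
        with torch.no_grad():
            for module in self.children():
                if hasattr(module, "normalize_matrices"):
                    module.normalize_matrices()
```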
136 changes: 87 additions & 49 deletions nemo/collections/asr/modules/ngpt_encoder.py
@@ -16,9 +16,11 @@
from collections import OrderedDict
from dataclasses import dataclass

from rotary_embedding_torch import RotaryEmbedding
Collaborator: double import

Collaborator (Author): Removed

import torch
import torch.distributed
import torch.nn as nn
from rotary_embedding_torch import RotaryEmbedding

from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling
from nemo.core.classes.common import typecheck
@@ -27,6 +29,7 @@
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types import AcousticEncodedRepresentation, LengthsType, NeuralType, SpectrogramType
from nemo.utils import logging
INF_VAL = 10000.0

try:
from flash_attn import flash_attn_func
@@ -174,6 +177,41 @@ def forward_for_export(

def streaming_post_process(self, rets, keep_all_outputs=True):
raise NotImplementedError()
def _create_masks(self, att_context_size, padding_length, max_audio_length, offset, device):

att_mask = torch.ones(1, max_audio_length, max_audio_length, dtype=torch.bool, device=device)


if att_context_size[0] >= 0:
att_mask = att_mask.triu(diagonal=-att_context_size[0])
if att_context_size[1] >= 0:
att_mask = att_mask.tril(diagonal=att_context_size[1])

# pad_mask is the masking to be used to ignore paddings
pad_mask = torch.arange(0, max_audio_length, device=device).expand(
padding_length.size(0), -1
) < padding_length.unsqueeze(-1)

if offset is not None:
pad_mask_off = torch.arange(0, max_audio_length, device=device).expand(
padding_length.size(0), -1
) >= offset.unsqueeze(-1)
pad_mask = pad_mask_off.logical_and(pad_mask)

if att_mask is not None:
# pad_mask_for_att_mask is the mask which helps to ignore paddings
pad_mask_for_att_mask = pad_mask.unsqueeze(1).repeat([1, max_audio_length, 1])
pad_mask_for_att_mask = torch.logical_and(pad_mask_for_att_mask, pad_mask_for_att_mask.transpose(1, 2))
# att_mask is the masking to be used by the MHA layers to ignore the tokens not supposed to be visible
att_mask = att_mask[:, :max_audio_length, :max_audio_length]
# paddings should also get ignored, so pad_mask_for_att_mask is used to ignore their corresponding scores
att_mask = torch.logical_and(pad_mask_for_att_mask, att_mask.to(pad_mask_for_att_mask.device))
att_mask = ~att_mask

pad_mask = ~pad_mask
return pad_mask, att_mask



@typecheck()
def forward(
@@ -197,7 +235,15 @@ def forward_internal(

audio_signal = audio_signal.transpose(1, 2)
x, length = self.pre_encode(x=audio_signal, lengths=length)
x = self.ngpt(x)
padding_length = length
pad_mask, att_mask = self._create_masks(
att_context_size=[-1, -1],
padding_length=padding_length,
max_audio_length=x.size(1),
offset=None,
device=x.device,
)
x = self.ngpt(x, mask=att_mask)
x = x.transpose(1, 2)

return x, length
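For intuition about what _create_masks returns (a standalone sketch mirroring its logic, not code from the PR): with the full-context setting [-1, -1] used above, the triu/tril steps are no-ops, so the attention mask ends up blocking only padded positions, and both returned masks use True to mark positions to ignore.

```python
import torch

# Standalone sketch of the padding portion of _create_masks (values assumed).
lengths = torch.tensor([4, 2])  # valid frames per utterance
T = 5                           # max_audio_length

keep = torch.arange(T).expand(2, -1) < lengths.unsqueeze(-1)   # True = real frame
att_keep = keep.unsqueeze(1) & keep.unsqueeze(2)               # (B, T, T)

pad_mask, att_mask = ~keep, ~att_keep  # inverted: True = ignore
print(pad_mask)
# tensor([[False, False, False, False,  True],
#         [False, False,  True,  True,  True]])
print(att_mask.shape)  # torch.Size([2, 5, 5])
```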
Expand All @@ -208,36 +254,15 @@ def normalize_matrices(self):
self.ngpt.normalize_matrices()


def apply_rotary_position_embeddings(sinusoidal_pos, q, k):
# Split the sinusoidal_pos into sin and cos parts
sin, cos = sinusoidal_pos.chunk(2, dim=-1)
# Apply the rotary embeddings to the query and key
q_rot = torch.stack((-q[..., 1::2], q[..., ::2]), dim=-1)
k_rot = torch.stack((-k[..., 1::2], k[..., ::2]), dim=-1)
q_rot = torch.reshape(q_rot, q.shape[:-1] + (q.shape[-1] // 2, 2)) * torch.stack((cos, sin), dim=-1)
k_rot = torch.reshape(k_rot, k.shape[:-1] + (k.shape[-1] // 2, 2)) * torch.stack((cos, sin), dim=-1)
q_rot = torch.reshape(q_rot, q.shape)
k_rot = torch.reshape(k_rot, k.shape)
return q_rot.to(q.dtype), k_rot.to(k.dtype)


def get_sinusoidal_embeddings(n_positions, dim, device):
"""Generate sinusoidal positional embeddings."""
position = torch.arange(n_positions, dtype=torch.float, device=device).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float, device=device) * (-math.log(10000.0) / dim))
sinusoidal_emb = torch.empty((n_positions, dim), device=device)
sinusoidal_emb[:, 0::2] = torch.sin(position * div_term)
sinusoidal_emb[:, 1::2] = torch.cos(position * div_term)
return sinusoidal_emb


def justnorm(x, fp32: bool = False, idim: int = -1):
def justnorm(x, fp32: bool = False, idim: int = -1, eps: float = 1e-10):
pzelasko (Collaborator), Jun 17, 2025: instead of 1e-10, which might not be representable in low-precision types, use torch.finfo(x.dtype).eps (https://docs.pytorch.org/docs/stable/type_info.html#torch.torch.finfo)

Collaborator (Author): Done

if fp32:
dtype = x.dtype
x = x.float()
res = (x / x.norm(p=2, dim=idim, keepdim=True)).to(dtype=dtype)
else:
res = x / x.norm(p=2, dim=idim, keepdim=True)
norm = x.norm(p=2, dim=idim, keepdim=True)
res = x / (norm + eps)
return res
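The reviewer's suggestion is marked "Done", but the resulting code is outside this 3-commit snapshot; a dtype-aware version would presumably look roughly like this sketch:

```python
import torch


def justnorm(x: torch.Tensor, fp32: bool = False, idim: int = -1) -> torch.Tensor:
    # Hypothetical revision: epsilon taken from the tensor's own dtype,
    # so it stays representable in low-precision types such as bf16/fp16.
    if fp32:
        dtype = x.dtype
        x = x.float()
    eps = torch.finfo(x.dtype).eps
    res = x / (x.norm(p=2, dim=idim, keepdim=True) + eps)
    return res.to(dtype=dtype) if fp32 else res
```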


@@ -247,15 +272,16 @@ def justnorm_fp32(x, idim: int = -1):

class Block(nn.Module):

def __init__(self, config, iblock):
def __init__(self, config, rotary_emb):
super().__init__()
self.config = config

self.rotary_emb = rotary_emb
self.key = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
self.query = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
self.value = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
self.att_c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

self.rotary_emb = RotaryEmbedding(dim=config.n_embd // config.n_head)
self.c_fc = nn.Linear(config.n_embd, 2 * 4 * config.n_embd, bias=config.bias)
self.silu = nn.SiLU()
self.mlp_c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
@@ -301,37 +327,48 @@ def forward(self, h, mask):
q = q.view(B, T, self.config.n_head, self.config.n_embd // self.config.n_head)
k = k.view(B, T, self.config.n_head, self.config.n_embd // self.config.n_head)
v = v.view(B, T, self.config.n_head, self.config.n_embd // self.config.n_head)

sinusoidal_pos = get_sinusoidal_embeddings(T, self.config.n_embd // self.config.n_head, device=q.device)
q, k = apply_rotary_position_embeddings(sinusoidal_pos, q.transpose(1, 2), k.transpose(1, 2))
q = self.rotary_emb.rotate_queries_or_keys(q.transpose(2, 1))
k = self.rotary_emb.rotate_queries_or_keys(k.transpose(2, 1))
q = q.transpose(2, 1)
k = k.transpose(2, 1)
B, T, H, D = q.shape

# Apply nGPT specific normalization if enabled
if self.config.use_nGPT == 1:
sqk = (self.sqk * (self.sqk_init_value / self.sqk_init_scaling)).view(
1, 1, self.config.n_head, self.config.n_embd // self.config.n_head
)
q = sqk * justnorm(q)
k = sqk * justnorm(k)

# Compute attention scaling factor
sqrt_head_dim = (self.config.n_embd / self.config.n_head) ** 0.5
if self.config.use_nGPT == 0:
softmax_scale = 1.0 / sqrt_head_dim
if self.config.use_nGPT == 1:
softmax_scale = sqrt_head_dim
y = flash_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=softmax_scale,
causal=False,
window_size=(-1, -1),
alibi_slopes=None,
deterministic=False,
)
y = y.to(dtype=q.dtype)
y = y.contiguous().view(B, T, self.config.n_embd)
softmax_scale = sqrt_head_dim if self.config.use_nGPT == 1 else 1.0 / sqrt_head_dim

# Reshape tensors for attention computation
q_ = q.permute(0, 2, 1, 3) # (B, H, T, D)
k_ = k.permute(0, 2, 1, 3) # (B, H, T, D)
v_ = v.permute(0, 2, 1, 3) # (B, H, T, D)

# Compute attention scores
scores = torch.matmul(q_, k_.transpose(-2, -1)) * softmax_scale # (B, H, T, T)

# Apply attention mask if provided
if mask is not None:
scores = scores.masked_fill(mask.unsqueeze(1), -INF_VAL)

# Softmax over the key dimension (dropout omitted)
attn = torch.softmax(scores, dim=-1)  # (B, H, T, T)
attn = attn.to(v_.dtype)  # match v_'s dtype (e.g. bf16)

if mask is not None:
attn = attn.masked_fill(mask.unsqueeze(1), 0.0)

out = torch.matmul(attn, v_) # (B, H, T, D)

y = out.permute(0, 2, 1, 3).contiguous().view(B, T, H * D)


h_att = self.att_c_proj(y)
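The replacement above spells out standard scaled dot-product attention (scores, mask, softmax, weighted sum), which also makes the nGPT-specific scale (sqrt(D) rather than 1/sqrt(D)) explicit. If a fused kernel were wanted without flash-attn, PyTorch's built-in SDPA could express the same computation; the following is a sketch under assumed shapes (PyTorch >= 2.1 for the scale argument), not what the PR does, and it has no analogue of the second masked_fill that zeroes fully-masked rows:

```python
import torch
import torch.nn.functional as F

# Assumed shapes: q_/k_/v_ are (B, H, T, D); att_mask is (B, T, T), True = ignore.
B, H, T, D = 2, 4, 6, 16
q_, k_, v_ = (torch.randn(B, H, T, D) for _ in range(3))
att_mask = torch.zeros(B, T, T, dtype=torch.bool)

out = F.scaled_dot_product_attention(
    q_, k_, v_,
    attn_mask=~att_mask.unsqueeze(1),  # SDPA boolean masks use True = attend
    scale=D ** -0.5,                   # an nGPT config would pass sqrt(D) instead
)
print(out.shape)  # torch.Size([2, 4, 6, 16])
```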

@@ -407,11 +444,12 @@ def __init__(self, config):
super().__init__()
self.config = config

self.rotary_emb = RotaryEmbedding(dim=config.n_embd // config.n_head)
self.transformer = nn.ModuleDict(
dict(
# wte=nn.Embedding(config.vocab_size, config.n_embd),
# drop=nn.Dropout(config.dropout),
h=nn.ModuleList([Block(config, il) for il in range(config.n_layer)])
h=nn.ModuleList([Block(config, rotary_emb=self.rotary_emb) for _ in range(config.n_layer)])
)
)
# self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
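A single RotaryEmbedding is now built once here and shared by every Block instead of each block constructing its own, which avoids duplicating the cached frequency buffers. A minimal sketch of the transpose convention the blocks use with it (shapes assumed from the surrounding code; rotary-embedding-torch rotates along the second-to-last axis by default):

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

B, T, H, D = 2, 10, 4, 16
rotary_emb = RotaryEmbedding(dim=D)  # created once, passed to every Block
q = torch.randn(B, T, H, D)          # layout inside Block.forward

# rotate_queries_or_keys expects (..., seq, dim), i.e. the sequence axis at -2,
# hence the transpose(2, 1) before and after in the block.
q_rot = rotary_emb.rotate_queries_or_keys(q.transpose(2, 1)).transpose(2, 1)
print(q_rot.shape)  # torch.Size([2, 10, 4, 16])
```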