diff --git a/megatron/model/rotary_pos_embedding.py b/megatron/model/rotary_pos_embedding.py
index 4d4497e0cd9..d36cc71490c 100644
--- a/megatron/model/rotary_pos_embedding.py
+++ b/megatron/model/rotary_pos_embedding.py
@@ -20,8 +20,9 @@ def __init__(self, dim, theta=10000):
             raise RuntimeError("einops is required for Rotary Embedding")
 
     def forward(self, max_seq_len, offset=0):
-        seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
-        freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
+        seq = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=torch.float) + offset
+        # Force float32 since bfloat16 loses precision on long contexts
+        freqs = einsum('i , j -> i j', seq, self.inv_freq.float())
         # first part even vector components, second part odd vector components,
         # 2 * dim in dimension size
         emb = torch.cat((freqs, freqs), dim=-1)
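
For context on why the patch forces float32: bfloat16 carries only 8 significand bits, so integer position indices above 256 are no longer exactly representable, and the rotary angles computed from them drift as the sequence grows. Below is a minimal sketch demonstrating the effect; it uses only stock PyTorch, and the `seq_len` value and drift check are illustrative, not part of the patch:

```python
import torch

# bfloat16 keeps only 8 significand bits, so integer positions above 256
# round to the nearest representable value; the rotary angles derived from
# them drift at long sequence lengths.
seq_len = 4096
positions = torch.arange(seq_len, dtype=torch.float32)

# Round-trip through bfloat16 and measure how far positions move.
drift = (positions - positions.to(torch.bfloat16).float()).abs()
print(drift.max())  # nonzero once positions exceed 256, growing with length
```

Computing `seq` and `freqs` in float32, as the patch does, keeps the angle computation exact; downstream code can still cast the resulting embedding back to the model dtype afterwards, since the precision-sensitive step is forming the angles themselves.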