122 changes: 111 additions & 11 deletions src/fairseq2/nn/_position_encoder.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -311,6 +311,7 @@
max_seq_len: int
theta: float
freqs_init_fn: Callable[[RotaryEncoder], Tensor] | None
impl: str

def __init__(
self,
@@ -320,6 +321,7 @@
theta: float = 10_000.0,
freqs_init_fn: Callable[[RotaryEncoder], Tensor] | None = None,
device: Device | None = None,
impl: str = "llama"
@djsaunde djsaunde Jun 23, 2025

Nit: could use Literal["value1", "value2", ...] typing for better imputation.
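
A minimal sketch of that suggestion (hypothetical, not part of this diff), assuming the two values the PR currently accepts:

from typing import Literal

RotaryImpl = Literal["llama", "reference"]  # hypothetical alias

class RotaryEncoder:
    def __init__(self, encoding_dim: int, *, impl: RotaryImpl = "llama") -> None:
        # Static type checkers (mypy/pyright) flag impl="typo" at call
        # sites; the runtime ValueError check can stay as a safety net.
        self.impl = impl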

) -> None:
"""
:param encoding_dim: The dimensionality of positional encodings. The
@@ -334,8 +336,14 @@
expected for the callable to return a :class:`~torch.Tensor` holding
the frequency table. If ``None``, the frequencies will be initialized
as described in the reference paper.
:param impl: Selects how embedding dimensions are paired into real/imaginary
components: consecutive elements form a pair ("llama"), or the first
half is paired element-wise with the second half ("reference").
Example for E = 8, [1,2,3,4,5,6,7,8]:
- "llama": [(1,2), (3,4), (5,6), (7,8)], i.e. [real0, imag0, real1, imag1, real2, imag2, real3, imag3]
- "reference": [(1,5), (2,6), (3,7), (4,8)], i.e. [real0, real1, real2, real3, imag0, imag1, imag2, imag3]

:raise ValueError: when ``encoding_dim`` is not even.
:raise ValueError: when ``impl`` is not a valid implementation selection.
"""
super().__init__(encoding_dim)

@@ -344,6 +352,12 @@
f"`encoding_dim` must be even, but is {encoding_dim} instead."
)

if impl not in ["llama", "reference"]:
raise ValueError(
f"`impl` must be one of [\"llama\", \"reference\"], but is {impl} instead."
)

# (S+1, E / 2, 2)
freqs = torch.empty(
(max_seq_len + 1, encoding_dim // 2, 2), device=device, dtype=torch.float32
)
@@ -356,6 +370,8 @@

self.freqs_init_fn = freqs_init_fn

self.impl = impl

self.reset_parameters()

def reset_parameters(self) -> None:
@@ -430,8 +446,11 @@

# (S, E / 2) -> (1, S, E / 2)
complex_freqs = complex_freqs.unsqueeze(0)

if self.impl == "reference":
seqs = self._split_to_consecutive_layout(tensor=seqs)

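# torch.view_as_complex expects adjacent (real, imag) pairs in the last
# dimension, so a split-half input is first converted to the consecutive
# layout (and converted back after the rotation below).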
# ([N], S, *, E) -> ([N], S, *, E / 2, 2)
seqs = seqs.unflatten(-1, (-1, 2))

# ([N], S, *, E / 2, 2) -> ([N], S, *, E / 2)
@@ -445,7 +464,47 @@
# ([N], S, *, E / 2) -> ([N], S, *, E)
fp32_seqs = torch.view_as_real(complex_seqs).flatten(-2)

if self.impl == "reference":
fp32_seqs = self._consecutive_to_split_layout(tensor=fp32_seqs)

return fp32_seqs.type_as(seqs)

def _consecutive_to_split_layout(self, tensor: Tensor) -> Tensor:
"""
Transforms the consecutive-pair layout into the split-half layout:
[1,2,3,4,5,6,7,8] -> [1,3,5,7,2,4,6,8]
"""
original_shape = tensor.shape
encoding_dim = original_shape[-1]
half_dim = encoding_dim // 2

# (*, E) -> (*, E / 2, 2)
pairs = tensor.view(*original_shape[:-1], half_dim, 2)

# (*, E / 2)
real_parts = pairs[..., 0]
# (*, E / 2)
imag_parts = pairs[..., 1]

# (*, E / 2) -> (*, E)
return torch.cat([real_parts, imag_parts], dim=-1)

def _split_to_consecutive_layout(self, tensor: Tensor) -> Tensor:
"""
Transforms the split-half layout into the consecutive-pair layout:
[1,3,5,7,2,4,6,8] -> [1,2,3,4,5,6,7,8]
"""
original_shape = tensor.shape
encoding_dim = original_shape[-1]
half_dim = encoding_dim // 2

# (*, E) -> (*, E / 2)
real_parts = tensor[..., :half_dim]
# (*, E) -> (*, E / 2)
imag_parts = tensor[..., half_dim:]

# 2 x (*, E / 2) -> (*, E / 2, 2)
pairs = torch.stack([real_parts, imag_parts], dim=-1)
# (*, E / 2, 2) -> (*, E)
return pairs.view(*original_shape)

@override
def extra_repr(self) -> str:
@@ -475,6 +534,7 @@
sin_freqs: Tensor
max_seq_len: int
theta: float
impl: str

def __init__(
self,
@@ -483,6 +543,7 @@
*,
theta: float = 10_000.0,
device: Device | None = None,
impl: str = "reference",
) -> None:
"""
:param encoding_dim: The dimensionality of positional encodings. The
@@ -492,8 +553,14 @@
Sequences longer than ``max_seq_len`` will cause a :class:`ValueError`.
:param theta: The coefficient of the long-term decay as described in
section 3.3 of the reference paper.
:param impl: Selects how embedding dimensions are paired into real/imaginary
components: consecutive elements form a pair ("llama"), or the first
half is paired element-wise with the second half ("reference").
Example for E = 8, [1,2,3,4,5,6,7,8]:
- "llama": [(1,2), (3,4), (5,6), (7,8)], i.e. [real0, imag0, real1, imag1, real2, imag2, real3, imag3]
- "reference": [(1,5), (2,6), (3,7), (4,8)], i.e. [real0, real1, real2, real3, imag0, imag1, imag2, imag3]

:raise ValueError: when ``encoding_dim`` is not even.
:raise ValueError: when ``impl`` is not a valid implementation selection.
"""
super().__init__(encoding_dim)

@@ -502,6 +569,11 @@
f"`encoding_dim` must be even, but is {encoding_dim} instead."
)

if impl not in ["reference", "llama"]:
raise ValueError(
f"`impl` must be one of [\"reference\", \"llama\"], but is {impl} instead."
)

cos_freqs = torch.empty(
(max_seq_len + 1, encoding_dim), device=device, dtype=torch.float32
)
@@ -517,6 +589,8 @@

self.theta = theta

self.impl = impl

self.reset_parameters()

def reset_parameters(self) -> None:
@@ -532,10 +606,10 @@

encoding_dim = self.encoding_dim

# (E / 2)
indices = torch.arange(encoding_dim // 2, device=device, dtype=dtype)

# (E / 2) -> (1, E / 2)
indices = indices.unsqueeze(0)

# (S)
Expand All @@ -544,17 +618,27 @@
# (S, 1)
steps = steps.unsqueeze(1)

# (S, 1) x (1, E / 2) -> (S, E / 2)
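# Frequency for pair i is theta ** (-2i / E), per section 3.3 of the
# reference paper.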
table = torch.matmul(steps, self.theta ** (-2.0 * indices / encoding_dim))

cos = torch.cos(table)
sin = torch.sin(table)

if self.impl == "reference":
# Split-half layout: [real0, real1, real2, real3, imag0, imag1, imag2, imag3]
self.cos_freqs[1:, : encoding_dim // 2] = cos
self.cos_freqs[1:, encoding_dim // 2 :] = cos

self.sin_freqs[1:, : encoding_dim // 2] = sin
self.sin_freqs[1:, encoding_dim // 2 :] = sin
else:  # llama
# Consecutive layout: [real0, imag0, real1, imag1, real2, imag2, real3, imag3]
self.cos_freqs[1:, 0::2] = cos
self.cos_freqs[1:, 1::2] = cos

self.sin_freqs[1:, 0::2] = sin
self.sin_freqs[1:, 1::2] = sin

@override
def forward(
@@ -607,25 +691,41 @@

fp32_seqs = seqs.float()

if self.impl == "reference":
fp32_rotated_seqs = self._rotate_half_way(fp32_seqs)
else:  # llama
fp32_rotated_seqs = self._rotate_consecutive_pairs(fp32_seqs)

fp32_seqs = (fp32_seqs * cos_freqs) + (fp32_rotated_seqs * sin_freqs)

return fp32_seqs.type_as(seqs)

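# Both helpers below implement the "rotate" term of RoPE's real-valued
# form, out = seqs * cos + rotate(seqs) * sin, mapping each (real, imag)
# pair (a, b) to (-b, a) in the corresponding layout.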
def _rotate_half_way(self, seqs: Tensor) -> Tensor:
"""Rotation for split-half layout: [1,2,3,4,5,6,7,8] -> [-5,-6,-7,-8,1,2,3,4]"""
half1 = seqs[..., : self.encoding_dim // 2]
half2 = seqs[..., self.encoding_dim // 2 :]

return torch.cat((-half2, half1), dim=-1)

def _rotate_consecutive_pairs(self, seqs: Tensor) -> Tensor:
"""Rotation for consecutive layout: [1,2,3,4,5,6,7,8] -> [-2,1,-4,3,-6,5,-8,7]"""
even_parts = seqs[..., 0::2]
odd_parts = seqs[..., 1::2]

result = torch.zeros_like(seqs)
result[..., 0::2] = -odd_parts
result[..., 1::2] = even_parts

return result

@override
def extra_repr(self) -> str:
""":meta private:"""
return (
f"encoding_dim={self.encoding_dim}, "
f"max_seq_len={self.max_seq_len}, "
f"theta={self.theta}"
f"theta={self.theta}, "
f"impl={self.impl}"
)

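As a sanity check on the two orderings, here is a minimal standalone sketch (not part of the diff; the names are illustrative) showing that the consecutive and split-half layouts are permutations of each other:

import torch

x = torch.arange(1, 9)  # [1,2,3,4,5,6,7,8], E = 8

# Consecutive layout -> split-half layout: [1,3,5,7,2,4,6,8]
pairs = x.view(4, 2)
split = torch.cat([pairs[:, 0], pairs[:, 1]])

# Split-half layout -> consecutive layout: round-trips back to x
consecutive = torch.stack([split[:4], split[4:]], dim=-1).view(8)

assert torch.equal(consecutive, x)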

Expand Down
Loading