Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ mscale: 1.0
rope_interleave: True # RoPE with sin/cos interleaved vs concatenated
rope_truncate: True # Floor lower bound and ceil upper bound for correction range
rope_attention_scaling: False # Scale the rotary embedding output
rope_use_rotation_matrix: False # Apply YaRN RoPE with a rotation matrix instead of complex multiplication. Requires rope_interleave: True.

# Ahead of time Compilation (aka AOT)
# Only set these arguments if you are running train_compile or loading a compiled train step.
Expand Down
4 changes: 3 additions & 1 deletion src/MaxText/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,9 @@ class YarnRope(BaseModel):
rope_interleave: bool = Field(True, description="Whether RoPE sin/cos are interleaved vs concatenated.")
rope_truncate: bool = Field(True, description="Whether to floor/ceil the correction range for YaRN.")
rope_attention_scaling: bool = Field(
False, description="Scale the rotary embedding output. Used by some models like gpt-oss."
False, description="Scale the rotary embedding output. Used by some models like gpt-oss.")
rope_use_rotation_matrix: bool = Field(
False, description="Whether to use a rotation matrix for YaRN. Can only be use with interleave=True"
)


Expand Down
1 change: 1 addition & 0 deletions src/MaxText/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,7 @@ def init_rotary_embedding(self):
interleave=self.config.rope_interleave,
truncate=self.config.rope_truncate,
attention_scaling=self.config.rope_attention_scaling,
use_rotation_matrix=self.config.rope_use_rotation_matrix,
rngs=self.rngs,
)
elif self.is_qwen3_next:
Expand Down
76 changes: 49 additions & 27 deletions src/MaxText/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,7 @@ class YarnRotaryEmbedding(nnx.Module):
rope_interleave: Whether complex representation is interleaved or concatenated.
rope_truncate: Whether or not to floor lower bound and ceil upper bound for correction range.
rope_attention_scaling: Whether or not to scale the rotary embedding output.
rope_use_rotation_matrix: Whether or not to use a rotation matrix for YaRN. Only supported when interleave is True.
rngs: rng keys passed in by nnx.bridge.to_linen.
"""

Expand All @@ -708,6 +709,7 @@ def __init__(
interleave=True,
truncate=True,
attention_scaling=False,
use_rotation_matrix=False,
# Not used in YarnRotaryEmbedding but passed in by nnx.bridge.to_linen.
# TODO: Remove when bridge no longer needed
rngs: nnx.Rngs = None,
Expand All @@ -725,12 +727,28 @@ def __init__(
self.interleave = interleave
self.truncate = truncate
self.attention_scaling = attention_scaling
self.use_rotation_matrix = use_rotation_matrix

if use_rotation_matrix:
if not interleave:
raise ValueError("Using rotation matrix is only supported with interleave=True.")
self.pairwise_swap_and_negate_mask = self._init_pairwise_swap_and_negate_mask()

if self.embedding_dims % 2:
raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.")

def _init_pairwise_swap_and_negate_mask(self):
  """Builds the square matrix that swaps adjacent pairs and flips one sign.

  Right-multiplying a vector `[x0, x1, x2, x3, ...]` by the returned matrix
  yields `[-x1, x0, -x3, x2, ...]`: each (even, odd) pair is swapped and the
  value landing on the even position is negated.

  Returns:
    An int32 array of shape [embedding_dims, embedding_dims].
  """
  dim = self.embedding_dims
  positions = jnp.arange(dim)
  is_even = positions % 2 == 0

  # Partner of every position inside its pair: [1, 0, 3, 2, 5, 4, ...]
  partner = jnp.where(is_even, positions + 1, positions - 1)
  # Sign applied per output column: -1 on even columns, +1 on odd columns.
  column_sign = jnp.where(is_even, -1, 1)

  # Row i of the permutation is the one-hot vector for partner[i].
  permutation = jnp.eye(dim, dtype=jnp.int32)[partner]
  return permutation * column_sign

@property
def freqs_cis(self):
def freqs(self):
"""Frequencies for rotary embedding."""
half_dim = self.embedding_dims // 2
# Compute base frequencies for each (even-indexed) dimension.
Expand All @@ -752,10 +770,7 @@ def freqs_cis(self):
# Precompute frequencies for all positions by taking the outer product.
t = jnp.arange(self.max_position_embeddings, dtype=jnp.float32) # shape [max_position_embeddings]
# This gives a [max_position_embeddings, half_dim] tensor with rows as time steps.
freqs = jnp.outer(t, freqs)

# Compute the complex “cis” values: exp(i * theta).
return jnp.exp(1j * freqs) # shape [max_position_embeddings, half_dim]
return jnp.outer(t, freqs)

def _find_correction_dim(self, num_rotations: float, dim: int, base: float, max_position_embeddings: int) -> float:
"""Compute the correction dimension for a given number of rotations."""
Expand Down Expand Up @@ -802,6 +817,32 @@ def _linear_ramp_factor(self, min_val: float, max_val: float, dim: int) -> Array
linear_func = (jnp.arange(dim, dtype=jnp.float32) - min_val) / (max_val - min_val)
return jnp.clip(linear_func, 0, 1)

def _perform_rotation(self, inputs: Array, freqs: Array) -> Array:
  """Rotates `inputs` by the angles given in `freqs`.

  Two implementations are available:
    * Rotation-matrix path (interleave and use_rotation_matrix both set):
      computes `x * cos(theta) + (x @ swap_negate) * sin(theta)` elementwise.
      The result keeps the element order of `inputs`.
    * Complex path (otherwise): pairs the inputs into complex numbers,
      multiplies by `exp(i * theta)`, and returns the real parts followed by
      the imaginary parts along the last dimension.

  Args:
    inputs: Array to rotate. The interleaved complex path unpacks its shape
      into exactly four dimensions.
    freqs: Rotation angles; the last dimension is expected to be half of the
      last dimension of `inputs`.

  Returns:
    The rotated array.
  """
  if self.interleave:
    if self.use_rotation_matrix:
      # Every angle applies to both members of a pair, so duplicate it.
      angles = jnp.repeat(freqs, 2, axis=-1)
      swapped = jnp.matmul(inputs, self.pairwise_swap_and_negate_mask)
      return inputs * jnp.cos(angles) + swapped * jnp.sin(angles)

    # Interleaved layout [real1, img1, real2, img2, ...]: regroup the last
    # dimension into (pair, 2) and peel off the two components.
    batch, seq, heads, head_dim = inputs.shape
    pairs = inputs.reshape(batch, seq, heads, head_dim // 2, 2)
    real_part = pairs[..., 0]
    imag_part = pairs[..., 1]
  else:
    # Concatenated layout [real1, real2, ..., img1, img2, ...].
    real_part, imag_part = jnp.split(inputs, 2, axis=-1)

  # Rotation is a complex multiplication by exp(i * theta).
  as_complex = real_part + 1j * imag_part  # shape: [B, S, N, half_dim]
  rotated = as_complex * jnp.exp(1j * freqs)  # shape: [B, S, N, half_dim]

  # Back to a real tensor laid out as [real1, real2, ..., img1, img2, ...].
  return jnp.concatenate([jnp.real(rotated), jnp.imag(rotated)], axis=-1)

def __call__(self, inputs: Array, position: None | Array = None) -> Array:
"""Applies the rotary positional embedding using the precomputed frequencies.

Expand All @@ -827,30 +868,11 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
position = position.astype(jnp.int32)

# Lookup the precomputed frequencies using the position indices.
# self.freqs_cis has shape [max_position_embeddings, half_dim] so we use jnp.take along axis 0.
# self.freqs has shape [max_position_embeddings, half_dim] so we use jnp.take along axis 0.
# After indexing, shape becomes [B, S, half_dim]; we then add an axis for the heads.
freqs = jnp.take(self.freqs_cis, position, axis=0) # shape: [B, S, half_dim]
freqs = jnp.take(self.freqs, position, axis=0) # shape: [B, S, half_dim]
freqs = freqs[:, :, jnp.newaxis, :] # shape: [B, S, 1, half_dim]

if self.interleave:
# Inputs with interleaved format [real1, img1, real2, img2, ...] at last dimension
# Convert the last dimension into a complex representation.
# First reshape so that each pair of numbers represents the real and imaginary parts.
B, S, N, H = inputs.shape
half_dim = H // 2
inputs_reshaped = inputs.reshape(B, S, N, half_dim, 2)
first_half, second_half = inputs_reshaped[..., 0], inputs_reshaped[..., 1]
else:
# Inputs with concatenated format [real1, real2, ..., img1, img2, ...] at last dimension
first_half, second_half = jnp.split(inputs, 2, axis=-1)

inputs_complex = first_half + 1j * second_half # shape: [B, S, N, half_dim]
# Apply the rotary transformation via complex multiplication.
rotated = inputs_complex * freqs # shape: [B, S, N, half_dim]
# Convert the complex result back to a real tensor.
# Split the complex number into its real and imaginary parts.
# [real1, real2, ..., img1, img2, ...]
output = jnp.concatenate([jnp.real(rotated), jnp.imag(rotated)], axis=-1)
output = self._perform_rotation(inputs, freqs)

if self.attention_scaling:
attention_scaling = 1.0 if self.rope_factor <= 1 else (0.1 * math.log(self.rope_factor) + 1.0)
Expand Down
Loading