diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index 64455a118..b2bd8c7b8 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -725,6 +725,7 @@ mscale: 1.0
 rope_interleave: True # RoPE with sin/cos interleaved vs concatenated
 rope_truncate: True # Floor lower bound and ceil upper bound for correction range
 rope_attention_scaling: False # Scale the rotary embedding output
+rope_use_rotation_matrix: False # Whether to use a rotation matrix for YaRN.

 # Ahead of time Compilation (aka AOT)
 # Only set these arguments if you are running train_compile or loading a compiled train step.
diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index 1dc615c9c..36f373c7a 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -995,7 +995,9 @@ class YarnRope(BaseModel):
   rope_interleave: bool = Field(True, description="Whether RoPE sin/cos are interleaved vs concatenated.")
   rope_truncate: bool = Field(True, description="Whether to floor/ceil the correction range for YaRN.")
   rope_attention_scaling: bool = Field(
-      False, description="Scale the rotary embedding output. Used by some models like gpt-oss."
+      False, description="Scale the rotary embedding output. Used by some models like gpt-oss.")
+  rope_use_rotation_matrix: bool = Field(
+      False, description="Whether to use a rotation matrix for YaRN. Can only be used with interleave=True."
   )

diff --git a/src/MaxText/layers/attentions.py b/src/MaxText/layers/attentions.py
index 0586953de..52f436793 100644
--- a/src/MaxText/layers/attentions.py
+++ b/src/MaxText/layers/attentions.py
@@ -766,6 +766,7 @@ def init_rotary_embedding(self):
           interleave=self.config.rope_interleave,
           truncate=self.config.rope_truncate,
           attention_scaling=self.config.rope_attention_scaling,
+          use_rotation_matrix=self.config.rope_use_rotation_matrix,
           rngs=self.rngs,
       )
     elif self.is_qwen3_next:
diff --git a/src/MaxText/layers/embeddings.py b/src/MaxText/layers/embeddings.py
index 76fb1b8a7..0ba77a960 100644
--- a/src/MaxText/layers/embeddings.py
+++ b/src/MaxText/layers/embeddings.py
@@ -691,6 +691,7 @@ class YarnRotaryEmbedding(nnx.Module):
     rope_interleave: Whether complex representation is interleaved or concatenated.
     rope_truncate: Whether or not to floor lower bound and ceil upper bound for correction range.
     rope_attention_scaling: Whether or not to scale the rotary embedding output.
+    rope_use_rotation_matrix: Whether or not to use a rotation matrix for YaRN.
     rngs: rng keys passed in by nnx.bridge.to_linen.
   """

@@ -708,6 +709,7 @@ def __init__(
       interleave=True,
       truncate=True,
       attention_scaling=False,
+      use_rotation_matrix=False,
       # Not used in YarnRotaryEmbedding but passed in by nnx.bridge.to_linen.
       # TODO: Remove when bridge no longer needed
       rngs: nnx.Rngs = None,
@@ -725,12 +727,28 @@ def __init__(
     self.interleave = interleave
     self.truncate = truncate
     self.attention_scaling = attention_scaling
+    self.use_rotation_matrix = use_rotation_matrix
+
+    if use_rotation_matrix:
+      if not interleave:
+        raise ValueError("Using rotation matrix is only supported with interleave=True.")
+      self.pairwise_swap_and_negate_mask = self._init_pairwise_swap_and_negate_mask()

     if self.embedding_dims % 2:
       raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.")

+  def _init_pairwise_swap_and_negate_mask(self):
+    """Matrix M such that x @ M swaps elements pairwise and negates the
+    even-indexed outputs: [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...]."""
+    indices = jnp.arange(self.embedding_dims)
+    # [1, 0, 3, 2, 5, 4, ...]
+    swap_indices = jnp.where(indices % 2 == 0, indices + 1, indices - 1)
+    negation_mask = jnp.where(indices % 2 == 0, -1, 1)
+
+    identity = jnp.eye(self.embedding_dims, dtype=jnp.int32)
+    return identity[swap_indices] * negation_mask
+
   @property
-  def freqs_cis(self):
+  def freqs(self):
     """Frequencies for rotary embedding."""
     half_dim = self.embedding_dims // 2
     # Compute base frequencies for each (even-indexed) dimension.
@@ -752,10 +770,7 @@ def freqs_cis(self):
     # Precompute frequencies for all positions by taking the outer product.
     t = jnp.arange(self.max_position_embeddings, dtype=jnp.float32)  # shape [max_position_embeddings]
     # This gives a [max_position_embeddings, half_dim] tensor with rows as time steps.
-    freqs = jnp.outer(t, freqs)
-
-    # Compute the complex “cis” values: exp(i * theta).
-    return jnp.exp(1j * freqs)  # shape [max_position_embeddings, half_dim]
+    return jnp.outer(t, freqs)

   def _find_correction_dim(self, num_rotations: float, dim: int, base: float, max_position_embeddings: int) -> float:
     """Compute the correction dimension for a given number of rotations."""
@@ -802,6 +817,32 @@ def _linear_ramp_factor(self, min_val: float, max_val: float, dim: int) -> Array
     linear_func = (jnp.arange(dim, dtype=jnp.float32) - min_val) / (max_val - min_val)
     return jnp.clip(linear_func, 0, 1)

+  def _perform_rotation(self, inputs: Array, freqs: Array) -> Array:
+    """Performs the rotation of the inputs using the precomputed frequencies."""
+    if self.interleave and self.use_rotation_matrix:
+      # Decompose the rotation matrix into elementwise cosine and sine multiplications.
+      # Repeat each angle for its (real, imag) pair: [t0, t0, t1, t1, ...].
+      freqs = jnp.repeat(freqs, 2, axis=-1)
+      # Note: this path keeps the interleaved layout, while the complex path below
+      # returns concatenated [real..., img...] halves. The two differ by a fixed
+      # permutation that cancels in the q·k dot product, since queries and keys
+      # are rotated the same way.
+      return inputs * jnp.cos(freqs) + jnp.matmul(inputs, self.pairwise_swap_and_negate_mask) * jnp.sin(freqs)
+
+    if self.interleave:
+      # Inputs with interleaved format [real1, img1, real2, img2, ...] at last dimension
+      # Convert the last dimension into a complex representation.
+      # First reshape so that each pair of numbers represents the real and imaginary parts.
+      B, S, N, H = inputs.shape
+      inputs_reshaped = inputs.reshape(B, S, N, H // 2, 2)
+      first_half, second_half = inputs_reshaped[..., 0], inputs_reshaped[..., 1]
+    else:
+      # Inputs with concatenated format [real1, real2, ..., img1, img2, ...] at last dimension
+      first_half, second_half = jnp.split(inputs, 2, axis=-1)
+
+    inputs_complex = first_half + 1j * second_half  # shape: [B, S, N, half_dim]
+    # Apply the rotary transformation via complex multiplication.
+    rotated = inputs_complex * jnp.exp(1j * freqs)  # shape: [B, S, N, half_dim]
+    # Convert the complex result back to a real tensor.
+    # Split the complex number into its real and imaginary parts.
+    # [real1, real2, ..., img1, img2, ...]
+    return jnp.concatenate([jnp.real(rotated), jnp.imag(rotated)], axis=-1)
+
   def __call__(self, inputs: Array, position: None | Array = None) -> Array:
-    """Applies the rotary positional embedding using the precomputed complex frequencies.
+    """Applies the rotary positional embedding using the precomputed frequencies.

@@ -827,30 +868,11 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
       position = position.astype(jnp.int32)

     # Lookup the precomputed frequencies using the position indices.
-    # self.freqs_cis has shape [max_position_embeddings, half_dim] so we use jnp.take along axis 0.
+    # self.freqs has shape [max_position_embeddings, half_dim] so we use jnp.take along axis 0.
     # After indexing, shape becomes [B, S, half_dim]; we then add an axis for the heads.
-    freqs = jnp.take(self.freqs_cis, position, axis=0)  # shape: [B, S, half_dim]
+    freqs = jnp.take(self.freqs, position, axis=0)  # shape: [B, S, half_dim]
     freqs = freqs[:, :, jnp.newaxis, :]  # shape: [B, S, 1, half_dim]
-
-    if self.interleave:
-      # Inputs with interleaved format [real1, img1, real2, img2, ...] at last dimension
-      # Convert the last dimension into a complex representation.
-      # First reshape so that each pair of numbers represents the real and imaginary parts.
-      B, S, N, H = inputs.shape
-      half_dim = H // 2
-      inputs_reshaped = inputs.reshape(B, S, N, half_dim, 2)
-      first_half, second_half = inputs_reshaped[..., 0], inputs_reshaped[..., 1]
-    else:
-      # Inputs with concatenated format [real1, real2, ..., img1, img2, ...] at last dimension
-      first_half, second_half = jnp.split(inputs, 2, axis=-1)
-
-    inputs_complex = first_half + 1j * second_half  # shape: [B, S, N, half_dim]
-    # Apply the rotary transformation via complex multiplication.
-    rotated = inputs_complex * freqs  # shape: [B, S, N, half_dim]
-    # Convert the complex result back to a real tensor.
-    # Split the complex number into its real and imaginary parts.
-    # [real1, real2, ..., img1, img2, ...]
-    output = jnp.concatenate([jnp.real(rotated), jnp.imag(rotated)], axis=-1)
+    output = self._perform_rotation(inputs, freqs)

     if self.attention_scaling:
       attention_scaling = 1.0 if self.rope_factor <= 1 else (0.1 * math.log(self.rope_factor) + 1.0)
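For reviewers who want to sanity-check the equivalence, below is a minimal standalone sketch (not part of the PR; swap_negate_matrix, rope_matrix, and rope_complex are hypothetical names, and it assumes an even head dimension). It shows that the rotation-matrix path and the complex-multiplication path agree up to the fixed layout permutation noted in the diff: the matrix path keeps the interleaved layout, while the complex path returns concatenated [real..., imag...] halves.

import jax.numpy as jnp

def swap_negate_matrix(dim):
  # M such that x @ M == [-x1, x0, -x3, x2, ...],
  # mirroring _init_pairwise_swap_and_negate_mask above.
  idx = jnp.arange(dim)
  swap = jnp.where(idx % 2 == 0, idx + 1, idx - 1)
  sign = jnp.where(idx % 2 == 0, -1, 1)
  return jnp.eye(dim, dtype=jnp.int32)[swap] * sign

def rope_matrix(x, theta):
  # Rotation via elementwise cos/sin; theta holds one angle per (real, imag) pair.
  angles = jnp.repeat(theta, 2, axis=-1)  # [t0, t0, t1, t1, ...]
  return x * jnp.cos(angles) + (x @ swap_negate_matrix(x.shape[-1])) * jnp.sin(angles)

def rope_complex(x, theta):
  # Rotation via complex multiplication on interleaved (real, imag) pairs,
  # returning concatenated [real..., imag...] halves like the diff's complex path.
  pairs = x.reshape(x.shape[:-1] + (-1, 2))
  z = (pairs[..., 0] + 1j * pairs[..., 1]) * jnp.exp(1j * theta)
  return jnp.concatenate([jnp.real(z), jnp.imag(z)], axis=-1)

x = jnp.arange(8, dtype=jnp.float32)
theta = jnp.linspace(0.0, 1.0, 4)
a = rope_matrix(x, theta)   # interleaved output [r0, i0, r1, i1, ...]
b = rope_complex(x, theta)  # concatenated output [r0, r1, ..., i0, i1, ...]
print(jnp.allclose(jnp.concatenate([a[0::2], a[1::2]]), b))  # True

Because the same permutation is applied to both queries and keys, the attention scores are unchanged, which is why the two code paths can coexist behind the rope_use_rotation_matrix flag.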