|
| 1 | +import math |
| 2 | + |
| 3 | +import keras |
| 4 | +from keras import ops |
| 5 | + |
| 6 | +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding |
| 7 | +from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm |
| 8 | +from keras_hub.src.utils.keras_utils import clone_initializer |
| 9 | +from keras_hub.src.utils.keras_utils import fused_attention_op_available |
| 10 | + |
| 11 | + |
| 12 | +class Qwen3MoeAttention(keras.layers.Layer): |
| 13 | + """A multi-head attention layer for Qwen3Moe models |
| 14 | + This attention implementation supports grouped-query attention (GQA) where |
| 15 | + the number of key-value heads can be less than the number of query heads. |
| 16 | +
|
| 17 | + Args: |
| 18 | + num_query_heads: int. Number of query heads. |
| 19 | + num_key_value_heads: int. Number of key/value heads (for GQA). |
| 20 | + head_dim: int. The dimension of each attention head. |
| 21 | + rope_max_wavelength: int. Maximum wavelength for RoPE (Rotary Position |
| 22 | + Embedding). |
| 23 | + rope_scaling_factor: float. Scaling factor for RoPE, used for extending |
| 24 | + context length. |
| 25 | + kernel_initializer: Initializer for the kernel weights. |
| 26 | + dropout: float. Dropout rate for attention weights. |
| 27 | + layer_norm_epsilon: float. The epsilon value for layer normalization. |
| 28 | + sliding_window_size: int. Size of the sliding window for attention. |
| 29 | + **kwargs: Additional keyword arguments to pass to the Layer. |
| 30 | + """ |
| 31 | + |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + num_query_heads, |
| 35 | + num_key_value_heads, |
| 36 | + head_dim=None, |
| 37 | + rope_max_wavelength=10000, |
| 38 | + rope_scaling_factor=1, |
| 39 | + kernel_initializer="glorot_uniform", |
| 40 | + dropout=0.0, |
| 41 | + layer_norm_epsilon=1e-6, |
| 42 | + sliding_window_size=None, |
| 43 | + **kwargs, |
| 44 | + ): |
| 45 | + super().__init__( |
| 46 | + **kwargs, |
| 47 | + ) |
| 48 | + self.num_query_heads = num_query_heads |
| 49 | + self.num_key_value_heads = num_key_value_heads |
| 50 | + self.head_dim = head_dim |
| 51 | + self.dropout = dropout |
| 52 | + |
| 53 | + self.layer_norm_epsilon = layer_norm_epsilon |
| 54 | + |
| 55 | + self.num_key_value_groups = num_query_heads // num_key_value_heads |
| 56 | + self.rope_max_wavelength = rope_max_wavelength |
| 57 | + |
| 58 | + self.kernel_initializer = keras.initializers.get( |
| 59 | + clone_initializer(kernel_initializer) |
| 60 | + ) |
| 61 | + |
| 62 | + self.rope_scaling_factor = rope_scaling_factor |
| 63 | + self.sliding_window_size = sliding_window_size |
| 64 | + |
| 65 | + def build(self, inputs_shape): |
| 66 | + # Einsum variables: |
| 67 | + # b = batch size |
| 68 | + # q = query length |
| 69 | + # k = key/value length |
| 70 | + # m = model dim |
| 71 | + # u = num query heads |
| 72 | + # v = num key/value heads |
| 73 | + # h = head dim |
| 74 | + hidden_dim = inputs_shape[-1] |
| 75 | + if not self.head_dim: |
| 76 | + self.head_dim = hidden_dim // self.num_query_heads |
| 77 | + |
| 78 | + self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim) |
| 79 | + self._query_dense = keras.layers.EinsumDense( |
| 80 | + equation="bqm,muh->bquh", |
| 81 | + output_shape=(None, self.num_query_heads, self.head_dim), |
| 82 | + kernel_initializer=self.kernel_initializer, |
| 83 | + dtype=self.dtype_policy, |
| 84 | + name="query", |
| 85 | + ) |
| 86 | + self._query_dense.build(inputs_shape) |
| 87 | + |
| 88 | + self._query_dense_layer_norm = Qwen3MoeLayerNorm( |
| 89 | + epsilon=self.layer_norm_epsilon, |
| 90 | + dtype=self.dtype_policy, |
| 91 | + head_dim=self.head_dim, |
| 92 | + name="query_dense_layernorm", |
| 93 | + ) |
| 94 | + self._query_dense_layer_norm.build(inputs_shape) |
| 95 | + |
| 96 | + self._key_dense = keras.layers.EinsumDense( |
| 97 | + equation="bkm,mvh->bkvh", |
| 98 | + output_shape=( |
| 99 | + None, |
| 100 | + self.num_key_value_heads, |
| 101 | + self.head_dim, |
| 102 | + ), |
| 103 | + kernel_initializer=self.kernel_initializer, |
| 104 | + dtype=self.dtype_policy, |
| 105 | + name="key", |
| 106 | + ) |
| 107 | + self._key_dense.build(inputs_shape) |
| 108 | + |
| 109 | + self._key_dense_layer_norm = Qwen3MoeLayerNorm( |
| 110 | + epsilon=self.layer_norm_epsilon, |
| 111 | + dtype=self.dtype_policy, |
| 112 | + head_dim=self.head_dim, |
| 113 | + name="key_dense_layernorm", |
| 114 | + ) |
| 115 | + self._key_dense_layer_norm.build(inputs_shape) |
| 116 | + |
| 117 | + self._value_dense = keras.layers.EinsumDense( |
| 118 | + equation="bkm,mvh->bkvh", |
| 119 | + output_shape=( |
| 120 | + None, |
| 121 | + self.num_key_value_heads, |
| 122 | + self.head_dim, |
| 123 | + ), |
| 124 | + kernel_initializer=self.kernel_initializer, |
| 125 | + dtype=self.dtype_policy, |
| 126 | + name="value", |
| 127 | + ) |
| 128 | + self._value_dense.build(inputs_shape) |
| 129 | + |
| 130 | + self._softmax = keras.layers.Softmax( |
| 131 | + axis=-1, |
| 132 | + dtype="float32", |
| 133 | + name="attention_softmax", |
| 134 | + ) |
| 135 | + |
| 136 | + self._dropout_layer = keras.layers.Dropout( |
| 137 | + rate=self.dropout, |
| 138 | + dtype=self.dtype_policy, |
| 139 | + ) |
| 140 | + |
| 141 | + self._output_dense = keras.layers.EinsumDense( |
| 142 | + equation="bquh,uhm->bqm", |
| 143 | + output_shape=(None, hidden_dim), |
| 144 | + kernel_initializer=self.kernel_initializer, |
| 145 | + dtype=self.dtype_policy, |
| 146 | + name="attention_output", |
| 147 | + ) |
| 148 | + self._output_dense.build( |
| 149 | + (None, None, self.num_query_heads, self.head_dim) |
| 150 | + ) |
| 151 | + |
| 152 | + self.rotary_embedding_layer = RotaryEmbedding( |
| 153 | + max_wavelength=self.rope_max_wavelength, |
| 154 | + scaling_factor=self.rope_scaling_factor, |
| 155 | + dtype=self.dtype_policy, |
| 156 | + ) |
| 157 | + |
| 158 | + self._dot_product_equation = "bquh,bkuh->buqk" |
| 159 | + self._combine_equation = "buqk,bkuh->bquh" |
| 160 | + |
| 161 | + self.built = True |
| 162 | + |
| 163 | + def call( |
| 164 | + self, |
| 165 | + hidden_states, |
| 166 | + attention_mask=None, |
| 167 | + cache=None, |
| 168 | + cache_update_index=None, |
| 169 | + training=None, |
| 170 | + ): |
| 171 | + """Applies attention mechanism to the input hidden states. |
| 172 | +
|
| 173 | + Args: |
| 174 | + hidden_states: Input tensor of shape [batch_size, seq_length, |
| 175 | + hidden_size]. |
| 176 | + attention_mask: Mask tensor of shape [batch_size, seq_length, |
| 177 | + seq_length]. |
| 178 | + cache: Optional cached key and value tensors. |
| 179 | + cache_update_index: Index at which to update the cache. |
| 180 | + training: Boolean indicating whether in training mode. |
| 181 | +
|
| 182 | + Returns: |
| 183 | + attention_output: Output tensor after applying attention. |
| 184 | + cache: Updated cache tensors (if cache is provided). |
| 185 | + """ |
| 186 | + start_index = ( |
| 187 | + cache_update_index if cache_update_index is not None else 0 |
| 188 | + ) |
| 189 | + |
| 190 | + query = self._query_dense(hidden_states) |
| 191 | + query = self._query_dense_layer_norm(query) |
| 192 | + |
| 193 | + # Compute RoPE for queries |
| 194 | + query = self.rotary_embedding_layer(query, start_index=start_index) |
| 195 | + |
| 196 | + def _compute_key_value(x): |
| 197 | + key = self._key_dense(x) |
| 198 | + key = self._key_dense_layer_norm(key) |
| 199 | + key = self.rotary_embedding_layer(key, start_index=start_index) |
| 200 | + |
| 201 | + value = self._value_dense(x) |
| 202 | + |
| 203 | + return key, value |
| 204 | + |
| 205 | + if cache is not None: |
| 206 | + key_cache = cache[:, 0, ...] |
| 207 | + value_cache = cache[:, 1, ...] |
| 208 | + if cache_update_index is None: |
| 209 | + key = key_cache |
| 210 | + value = value_cache |
| 211 | + else: |
| 212 | + key_update, value_update = _compute_key_value(hidden_states) |
| 213 | + start = [0, cache_update_index, 0, 0] |
| 214 | + key = ops.slice_update(key_cache, start, key_update) |
| 215 | + value = ops.slice_update(value_cache, start, value_update) |
| 216 | + cache = ops.stack((key, value), axis=1) |
| 217 | + else: |
| 218 | + if cache_update_index is not None: |
| 219 | + raise ValueError( |
| 220 | + "`cache_update_index` should not be set if `cache` is " |
| 221 | + f"`None`. Received: cache={cache}, " |
| 222 | + f"cache_update_index={cache_update_index}" |
| 223 | + ) |
| 224 | + key, value = _compute_key_value(hidden_states) |
| 225 | + |
| 226 | + # [batch_shape, seq_len, num_key_value_heads, head_dim] |
| 227 | + # -> [batch_shape, seq_len, num_heads, head_dim] |
| 228 | + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) |
| 229 | + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) |
| 230 | + |
| 231 | + attention_output = self._compute_attention( |
| 232 | + query, |
| 233 | + key, |
| 234 | + value, |
| 235 | + attention_mask, |
| 236 | + cache_update_index=cache_update_index, |
| 237 | + ) |
| 238 | + |
| 239 | + attention_output = self._dropout_layer( |
| 240 | + attention_output, training=training |
| 241 | + ) |
| 242 | + |
| 243 | + attention_output = self._output_dense(attention_output) |
| 244 | + |
| 245 | + if cache is not None: |
| 246 | + return attention_output, cache |
| 247 | + return attention_output |
| 248 | + |
| 249 | + def _masked_softmax(self, attention_scores, attention_mask=None): |
| 250 | + """Applies softmax with optional masking. |
| 251 | +
|
| 252 | + Args: |
| 253 | + attention_scores: Attention score tensor. |
| 254 | + attention_mask: Optional mask tensor. |
| 255 | +
|
| 256 | + Returns: |
| 257 | + Masked softmax attention weights. |
| 258 | + """ |
| 259 | + if attention_mask is not None: |
| 260 | + return self._softmax( |
| 261 | + attention_scores, attention_mask[:, None, :, :] |
| 262 | + ) |
| 263 | + return self._softmax(attention_scores) |
| 264 | + |
| 265 | + def _compute_attention( |
| 266 | + self, query, key, value, attention_mask=None, cache_update_index=None |
| 267 | + ): |
| 268 | + """Computes attention using query, key, and value tensors. |
| 269 | + Uses Flash Attention when available for better performance. |
| 270 | +
|
| 271 | + Args: |
| 272 | + query: Query tensor. |
| 273 | + key: Key tensor. |
| 274 | + value: Value tensor. |
| 275 | + attention_mask: Optional mask tensor. |
| 276 | + cache_update_index: Index for sliding window computation. |
| 277 | +
|
| 278 | + Returns: |
| 279 | + attention_output: Output tensor after applying attention. |
| 280 | + """ |
| 281 | + if fused_attention_op_available(): |
| 282 | + # Use `dot_product_attention` with Flash Attention support if |
| 283 | + # available. |
| 284 | + if attention_mask is not None: |
| 285 | + attention_mask = ops.expand_dims(attention_mask, axis=1) |
| 286 | + attention_mask = ops.cast(attention_mask, dtype="bool") |
| 287 | + attention_output = ops.dot_product_attention( |
| 288 | + query, |
| 289 | + key, |
| 290 | + value, |
| 291 | + mask=attention_mask, |
| 292 | + scale=self._inv_norm_factor, |
| 293 | + ) |
| 294 | + return attention_output |
| 295 | + |
| 296 | + attention_scores = ops.einsum(self._dot_product_equation, query, key) |
| 297 | + |
| 298 | + attention_scores = ops.multiply( |
| 299 | + attention_scores, |
| 300 | + ops.cast(self._inv_norm_factor, self.compute_dtype), |
| 301 | + ) |
| 302 | + if self.sliding_window_size: |
| 303 | + attention_mask = self._mask_sliding_window( |
| 304 | + attention_mask, |
| 305 | + cache_update_index=cache_update_index |
| 306 | + if cache_update_index is not None |
| 307 | + else 0, |
| 308 | + ) |
| 309 | + attention_scores = self._masked_softmax( |
| 310 | + attention_scores, attention_mask |
| 311 | + ) |
| 312 | + attention_scores = ops.cast(attention_scores, self.compute_dtype) |
| 313 | + attention_output = ops.einsum( |
| 314 | + self._combine_equation, attention_scores, value |
| 315 | + ) |
| 316 | + |
| 317 | + return attention_output |
| 318 | + |
| 319 | + def _mask_sliding_window( |
| 320 | + self, |
| 321 | + attention_mask, |
| 322 | + cache_update_index=0, |
| 323 | + ): |
| 324 | + """Creates and combines a sliding window mask with the attention mask. |
| 325 | +
|
| 326 | + Args: |
| 327 | + attention_mask: Original attention mask. |
| 328 | + cache_update_index: Starting index for the sliding window. |
| 329 | +
|
| 330 | + Returns: |
| 331 | + Combined attention mask with sliding window constraints. |
| 332 | + """ |
| 333 | + _, query_len, key_len = ops.shape(attention_mask) |
| 334 | + # Compute the sliding window for square attention. |
| 335 | + all_ones = ops.ones((key_len, key_len), "bool") |
| 336 | + if keras.config.backend() == "tensorflow": |
| 337 | + # TODO: trui/tril has issues with dynamic shape on the tensorflow |
| 338 | + # backend. We should fix, but use `band_part` for now. |
| 339 | + import tensorflow as tf |
| 340 | + |
| 341 | + band_size = ops.minimum(key_len, self.sliding_window_size - 1) |
| 342 | + band_size = ops.cast(band_size, "int32") |
| 343 | + sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size) |
| 344 | + else: |
| 345 | + sliding_mask = ops.triu( |
| 346 | + all_ones, -1 * self.sliding_window_size + 1 |
| 347 | + ) * ops.tril(all_ones, self.sliding_window_size - 1) |
| 348 | + # Slice the window for short queries during generation. |
| 349 | + start = (cache_update_index, 0) |
| 350 | + sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len)) |
| 351 | + sliding_mask = ops.expand_dims(sliding_mask, 0) |
| 352 | + return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool")) |
| 353 | + |
| 354 | + def get_config(self): |
| 355 | + config = super().get_config() |
| 356 | + config.update( |
| 357 | + { |
| 358 | + "num_query_heads": self.num_query_heads, |
| 359 | + "num_key_value_heads": self.num_key_value_heads, |
| 360 | + "rope_max_wavelength": self.rope_max_wavelength, |
| 361 | + "rope_scaling_factor": self.rope_scaling_factor, |
| 362 | + "kernel_initializer": keras.initializers.serialize( |
| 363 | + self.kernel_initializer |
| 364 | + ), |
| 365 | + "dropout": self.dropout, |
| 366 | + "sliding_window_size": self.sliding_window_size, |
| 367 | + "head_dim": self.head_dim, |
| 368 | + "layer_norm_epsilon": self.layer_norm_epsilon, |
| 369 | + } |
| 370 | + ) |
| 371 | + return config |
0 commit comments