Commit cd943b1

Add Qwen3 Moe (#2260)

* qwen3 moe init
* bug fixes
* update
* address comments
* address comments
* update output matching script
* fix test
* update qwen3 causal lm
* output matching + bug fixes
* uncomment flag in conversion script
* api gen
* address comments
* chore: address comment
* code format + one comment
* address comments
* add causal lm + preprocessor tests
* chore: address feedback
* test fix
* keras_hub/src/utils/transformers/convert_qwen3_moe_test.py

1 parent 507f852 commit cd943b1

16 files changed: +2567 -0 lines changed

keras_hub/api/models/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -527,6 +527,15 @@
 from keras_hub.src.models.qwen3.qwen3_tokenizer import (
     Qwen3Tokenizer as Qwen3Tokenizer,
 )
+from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import (
+    Qwen3MoeBackbone as Qwen3MoeBackbone,
+)
+from keras_hub.src.models.qwen3_moe.qwen3_moe_causal_lm import (
+    Qwen3MoeCausalLM as Qwen3MoeCausalLM,
+)
+from keras_hub.src.models.qwen3_moe.qwen3_moe_causal_lm_preprocessor import (
+    Qwen3MoeCausalLMPreprocessor as Qwen3MoeCausalLMPreprocessor,
+)
 from keras_hub.src.models.qwen_moe.qwen_moe_backbone import (
     QwenMoeBackbone as QwenMoeBackbone,
 )
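
With these exports, the new classes become reachable as `keras_hub.models.Qwen3MoeBackbone`, `keras_hub.models.Qwen3MoeCausalLM`, and `keras_hub.models.Qwen3MoeCausalLMPreprocessor`. A minimal usage sketch follows; the preset name is hypothetical, since this commit adds the classes but does not register any Qwen3 MoE presets.

import keras_hub

# Hypothetical preset name, used only for illustration; this commit does not
# register any Qwen3 MoE presets.
causal_lm = keras_hub.models.Qwen3MoeCausalLM.from_preset("qwen3_moe_30b_a3b_en")

# generate() runs preprocessing, sampling, and detokenization end to end.
print(causal_lm.generate("What is a mixture-of-experts model?", max_length=64))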

keras_hub/api/tokenizers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -78,6 +78,9 @@
 from keras_hub.src.models.qwen.qwen_tokenizer import (
     QwenTokenizer as QwenTokenizer,
 )
+from keras_hub.src.models.qwen3_moe.qwen3_moe_tokenizer import (
+    Qwen3MoeTokenizer as Qwen3MoeTokenizer,
+)
 from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import (
     QwenMoeTokenizer as QwenMoeTokenizer,
 )
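
The matching tokenizer is exported under `keras_hub.tokenizers`. A short sketch, again with a hypothetical preset name:

import keras_hub

# Hypothetical preset name, used only for illustration.
tokenizer = keras_hub.tokenizers.Qwen3MoeTokenizer.from_preset("qwen3_moe_30b_a3b_en")

token_ids = tokenizer("Qwen3 MoE routes tokens to a subset of experts.")
print(tokenizer.detokenize(token_ids))
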
New file (defines the Qwen3MoeAttention layer)

Lines changed: 371 additions & 0 deletions
@@ -0,0 +1,371 @@
import math

import keras
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import fused_attention_op_available


class Qwen3MoeAttention(keras.layers.Layer):
    """A multi-head attention layer for Qwen3Moe models.

    This attention implementation supports grouped-query attention (GQA) where
    the number of key-value heads can be less than the number of query heads.

    Args:
        num_query_heads: int. Number of query heads.
        num_key_value_heads: int. Number of key/value heads (for GQA).
        head_dim: int. The dimension of each attention head.
        rope_max_wavelength: int. Maximum wavelength for RoPE (Rotary Position
            Embedding).
        rope_scaling_factor: float. Scaling factor for RoPE, used for extending
            context length.
        kernel_initializer: Initializer for the kernel weights.
        dropout: float. Dropout rate for attention weights.
        layer_norm_epsilon: float. The epsilon value for layer normalization.
        sliding_window_size: int. Size of the sliding window for attention.
        **kwargs: Additional keyword arguments to pass to the Layer.
    """

    def __init__(
        self,
        num_query_heads,
        num_key_value_heads,
        head_dim=None,
        rope_max_wavelength=10000,
        rope_scaling_factor=1,
        kernel_initializer="glorot_uniform",
        dropout=0.0,
        layer_norm_epsilon=1e-6,
        sliding_window_size=None,
        **kwargs,
    ):
        super().__init__(
            **kwargs,
        )
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.dropout = dropout

        self.layer_norm_epsilon = layer_norm_epsilon

        self.num_key_value_groups = num_query_heads // num_key_value_heads
        self.rope_max_wavelength = rope_max_wavelength

        self.kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )

        self.rope_scaling_factor = rope_scaling_factor
        self.sliding_window_size = sliding_window_size

    def build(self, inputs_shape):
        # Einsum variables:
        # b = batch size
        # q = query length
        # k = key/value length
        # m = model dim
        # u = num query heads
        # v = num key/value heads
        # h = head dim
        hidden_dim = inputs_shape[-1]
        if not self.head_dim:
            self.head_dim = hidden_dim // self.num_query_heads

        self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self._query_dense = keras.layers.EinsumDense(
            equation="bqm,muh->bquh",
            output_shape=(None, self.num_query_heads, self.head_dim),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="query",
        )
        self._query_dense.build(inputs_shape)

        self._query_dense_layer_norm = Qwen3MoeLayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            head_dim=self.head_dim,
            name="query_dense_layernorm",
        )
        self._query_dense_layer_norm.build(inputs_shape)

        self._key_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                self.head_dim,
            ),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="key",
        )
        self._key_dense.build(inputs_shape)

        self._key_dense_layer_norm = Qwen3MoeLayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            head_dim=self.head_dim,
            name="key_dense_layernorm",
        )
        self._key_dense_layer_norm.build(inputs_shape)

        self._value_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                self.head_dim,
            ),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="value",
        )
        self._value_dense.build(inputs_shape)

        self._softmax = keras.layers.Softmax(
            axis=-1,
            dtype="float32",
            name="attention_softmax",
        )

        self._dropout_layer = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
        )

        self._output_dense = keras.layers.EinsumDense(
            equation="bquh,uhm->bqm",
            output_shape=(None, hidden_dim),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self._output_dense.build(
            (None, None, self.num_query_heads, self.head_dim)
        )

        self.rotary_embedding_layer = RotaryEmbedding(
            max_wavelength=self.rope_max_wavelength,
            scaling_factor=self.rope_scaling_factor,
            dtype=self.dtype_policy,
        )

        self._dot_product_equation = "bquh,bkuh->buqk"
        self._combine_equation = "buqk,bkuh->bquh"

        self.built = True

    def call(
        self,
        hidden_states,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        """Applies attention mechanism to the input hidden states.

        Args:
            hidden_states: Input tensor of shape [batch_size, seq_length,
                hidden_size].
            attention_mask: Mask tensor of shape [batch_size, seq_length,
                seq_length].
            cache: Optional cached key and value tensors.
            cache_update_index: Index at which to update the cache.
            training: Boolean indicating whether in training mode.

        Returns:
            attention_output: Output tensor after applying attention.
            cache: Updated cache tensors (if cache is provided).
        """
        start_index = (
            cache_update_index if cache_update_index is not None else 0
        )

        query = self._query_dense(hidden_states)
        query = self._query_dense_layer_norm(query)

        # Compute RoPE for queries
        query = self.rotary_embedding_layer(query, start_index=start_index)

        def _compute_key_value(x):
            key = self._key_dense(x)
            key = self._key_dense_layer_norm(key)
            key = self.rotary_embedding_layer(key, start_index=start_index)

            value = self._value_dense(x)

            return key, value

        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                key = key_cache
                value = value_cache
            else:
                key_update, value_update = _compute_key_value(hidden_states)
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key_update)
                value = ops.slice_update(value_cache, start, value_update)
                cache = ops.stack((key, value), axis=1)
        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
            key, value = _compute_key_value(hidden_states)

        # [batch_shape, seq_len, num_key_value_heads, head_dim]
        # -> [batch_shape, seq_len, num_heads, head_dim]
        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

        attention_output = self._compute_attention(
            query,
            key,
            value,
            attention_mask,
            cache_update_index=cache_update_index,
        )

        attention_output = self._dropout_layer(
            attention_output, training=training
        )

        attention_output = self._output_dense(attention_output)

        if cache is not None:
            return attention_output, cache
        return attention_output

    def _masked_softmax(self, attention_scores, attention_mask=None):
        """Applies softmax with optional masking.

        Args:
            attention_scores: Attention score tensor.
            attention_mask: Optional mask tensor.

        Returns:
            Masked softmax attention weights.
        """
        if attention_mask is not None:
            return self._softmax(
                attention_scores, attention_mask[:, None, :, :]
            )
        return self._softmax(attention_scores)

    def _compute_attention(
        self, query, key, value, attention_mask=None, cache_update_index=None
    ):
        """Computes attention using query, key, and value tensors.

        Uses Flash Attention when available for better performance.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            attention_mask: Optional mask tensor.
            cache_update_index: Index for sliding window computation.

        Returns:
            attention_output: Output tensor after applying attention.
        """
        if fused_attention_op_available():
            # Use `dot_product_attention` with Flash Attention support if
            # available.
            if attention_mask is not None:
                attention_mask = ops.expand_dims(attention_mask, axis=1)
                attention_mask = ops.cast(attention_mask, dtype="bool")
            attention_output = ops.dot_product_attention(
                query,
                key,
                value,
                mask=attention_mask,
                scale=self._inv_norm_factor,
            )
            return attention_output

        attention_scores = ops.einsum(self._dot_product_equation, query, key)

        attention_scores = ops.multiply(
            attention_scores,
            ops.cast(self._inv_norm_factor, self.compute_dtype),
        )
        if self.sliding_window_size:
            attention_mask = self._mask_sliding_window(
                attention_mask,
                cache_update_index=cache_update_index
                if cache_update_index is not None
                else 0,
            )
        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )
        attention_scores = ops.cast(attention_scores, self.compute_dtype)
        attention_output = ops.einsum(
            self._combine_equation, attention_scores, value
        )

        return attention_output

    def _mask_sliding_window(
        self,
        attention_mask,
        cache_update_index=0,
    ):
        """Creates and combines a sliding window mask with the attention mask.

        Args:
            attention_mask: Original attention mask.
            cache_update_index: Starting index for the sliding window.

        Returns:
            Combined attention mask with sliding window constraints.
        """
        _, query_len, key_len = ops.shape(attention_mask)
        # Compute the sliding window for square attention.
        all_ones = ops.ones((key_len, key_len), "bool")
        if keras.config.backend() == "tensorflow":
            # TODO: triu/tril has issues with dynamic shape on the tensorflow
            # backend. We should fix, but use `band_part` for now.
            import tensorflow as tf

            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
            band_size = ops.cast(band_size, "int32")
            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
        else:
            sliding_mask = ops.triu(
                all_ones, -1 * self.sliding_window_size + 1
            ) * ops.tril(all_ones, self.sliding_window_size - 1)
        # Slice the window for short queries during generation.
        start = (cache_update_index, 0)
        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
        sliding_mask = ops.expand_dims(sliding_mask, 0)
        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "dropout": self.dropout,
                "sliding_window_size": self.sliding_window_size,
                "head_dim": self.head_dim,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config
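
Since the docstring above describes grouped-query attention, here is a standalone forward-pass sketch of the layer in a GQA configuration. The import path is an assumption (the new file's path is not visible in this view), and the shapes are arbitrary; it is an illustration, not part of the commit.

import keras
from keras import ops

# Assumed module path for the new file added in this commit.
from keras_hub.src.models.qwen3_moe.qwen3_moe_attention import (
    Qwen3MoeAttention,
)

batch, seq_len, hidden_dim = 2, 16, 512

# GQA: 8 query heads share 2 key/value heads, so num_key_value_groups == 4 and
# the key/value projections are repeated 4x before the score einsum.
layer = Qwen3MoeAttention(
    num_query_heads=8,
    num_key_value_heads=2,
    head_dim=64,
)

hidden_states = keras.random.uniform((batch, seq_len, hidden_dim))

# Boolean causal mask of shape [batch, query_len, key_len], as expected by call().
causal = ops.cast(ops.tril(ops.ones((seq_len, seq_len), "int32")), "bool")
attention_mask = ops.broadcast_to(
    ops.expand_dims(causal, 0), (batch, seq_len, seq_len)
)

output = layer(hidden_states, attention_mask=attention_mask)
print(ops.shape(output))  # (2, 16, 512)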
