diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index 308321717c..68215c9369 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -536,6 +536,15 @@ from keras_hub.src.models.qwen3_moe.qwen3_moe_causal_lm_preprocessor import (
     Qwen3MoeCausalLMPreprocessor as Qwen3MoeCausalLMPreprocessor,
 )
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_backbone import (
+    Qwen3OmniMoeBackbone as Qwen3OmniMoeBackbone,
+)
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_causal_lm import (
+    Qwen3OmniMoeCausalLM as Qwen3OmniMoeCausalLM,
+)
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_causal_lm_preprocessor import (
+    Qwen3OmniMoeCausalLMPreprocessor as Qwen3OmniMoeCausalLMPreprocessor,
+)
 from keras_hub.src.models.qwen_moe.qwen_moe_backbone import (
     QwenMoeBackbone as QwenMoeBackbone,
 )
diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py
index b155d0e6e1..af00f6a398 100644
--- a/keras_hub/api/tokenizers/__init__.py
+++ b/keras_hub/api/tokenizers/__init__.py
@@ -81,6 +81,9 @@ from keras_hub.src.models.qwen3_moe.qwen3_moe_tokenizer import (
     Qwen3MoeTokenizer as Qwen3MoeTokenizer,
 )
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_tokenizer import (
+    Qwen3OmniMoeTokenizer as Qwen3OmniMoeTokenizer,
+)
 from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import (
     QwenMoeTokenizer as QwenMoeTokenizer,
 )
diff --git a/keras_hub/src/models/qwen3_omni_moe/__init__.py b/keras_hub/src/models/qwen3_omni_moe/__init__.py
new file mode 100644
index 0000000000..051b2c6711
--- /dev/null
+++ b/keras_hub/src/models/qwen3_omni_moe/__init__.py
@@ -0,0 +1,8 @@
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_backbone import Qwen3OmniMoeBackbone
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_causal_lm import Qwen3OmniMoeCausalLM
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_causal_lm_preprocessor import Qwen3OmniMoeCausalLMPreprocessor
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_presets import backbone_presets
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_tokenizer import Qwen3OmniMoeTokenizer
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, Qwen3OmniMoeBackbone)
diff --git a/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_attention.py b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_attention.py
new file mode 100644
index 0000000000..b046a7c486
--- /dev/null
+++ b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_attention.py
@@ -0,0 +1,218 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_layernorm import Qwen3OmniMoeLayerNorm
+
+
+@keras_hub_export("keras_hub.models.Qwen3OmniMoeAttention")
+class Qwen3OmniMoeAttention(keras.layers.Layer):
+    """Multi-head attention for Qwen3-Omni MoE model.
+
+    This layer implements multi-head attention with grouped query attention (GQA)
+    and rotary positional embeddings for the Qwen3-Omni MoE model. It supports
+    efficient key-value caching for autoregressive generation.
+
+    Args:
+        num_query_heads: int. The number of heads for the query projections.
+        num_key_value_heads: int. The number of heads for the key and value
+            projections (must be <= num_query_heads).
+        hidden_dim: int. The size of the transformer hidden state.
+        head_dim: int, optional.
The size of each attention head. If None, + defaults to hidden_dim // num_query_heads. + layer_norm_epsilon: float, default 1e-6. The epsilon value used for + layer normalization. + dropout: float, default 0.0. Dropout probability for attention weights. + sliding_window_size: int, default 4096. Size of the sliding local window. + max_sequence_length: int, default 32768. The maximum sequence length + supported by the model. + dtype: str or `keras.mixed_precision.DTypePolicy`, optional. The dtype + to use for the layer's computations and weights. + + Example: + ```python + # Create attention layer + attention = Qwen3OmniMoeAttention( + num_query_heads=32, + num_key_value_heads=4, + hidden_dim=4096, + head_dim=128 + ) + + # Apply to input + hidden_states = keras.random.normal((2, 10, 4096)) + outputs = attention(hidden_states) + # outputs["hidden_states"] shape: (2, 10, 4096) + # outputs["cache"] contains key-value cache for generation + ``` + """ + + def __init__( + self, + num_query_heads, + num_key_value_heads, + hidden_dim, + head_dim, + layer_norm_epsilon=1e-6, + dropout=0.0, + sliding_window_size=4096, + max_sequence_length=32768, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.head_dim = head_dim if head_dim is not None else hidden_dim // num_query_heads + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.sliding_window_size = sliding_window_size + self.max_sequence_length = max_sequence_length + + # Query projection + self.query_projection = keras.layers.Dense( + num_query_heads * self.head_dim, + use_bias=False, + dtype=dtype, + name="query_projection", + ) + + # Key projection + self.key_projection = keras.layers.Dense( + num_key_value_heads * self.head_dim, + use_bias=False, + dtype=dtype, + name="key_projection", + ) + + # Value projection + self.value_projection = keras.layers.Dense( + num_key_value_heads * self.head_dim, + use_bias=False, + dtype=dtype, + name="value_projection", + ) + + # Output projection + self.output_projection = keras.layers.Dense( + hidden_dim, + use_bias=False, + dtype=dtype, + name="output_projection", + ) + + # Rotary embedding + self.rotary_embedding = RotaryEmbedding( + max_wavelength=10000, + scaling_factor=1.0, + dtype=dtype, + name="rotary_embedding", + ) + + def call( + self, + hidden_states, + attention_mask=None, + position_ids=None, + cache=None, + cache_update_index=None, + training=None, + ): + batch_size, seq_len, hidden_dim = ops.shape(hidden_states) + + # Project to query, key, value + query = self.query_projection(hidden_states) + key = self.key_projection(hidden_states) + value = self.value_projection(hidden_states) + + # Reshape for multi-head attention + query = ops.reshape( + query, (batch_size, seq_len, self.num_query_heads, self.head_dim) + ) + key = ops.reshape( + key, (batch_size, seq_len, self.num_key_value_heads, self.head_dim) + ) + value = ops.reshape( + value, (batch_size, seq_len, self.num_key_value_heads, self.head_dim) + ) + + # Apply rotary embedding + if position_ids is not None: + query = self.rotary_embedding(query, position_ids) + key = self.rotary_embedding(key, position_ids) + + # Handle cache + if cache is not None and cache_update_index is not None: + # Update cache + key = ops.concatenate([cache["key"], key], axis=1) + value = ops.concatenate([cache["value"], value], axis=1) + + # Update cache + new_cache = { + "key": key, + "value": 
value,
+        }
+
+        # Transpose for attention
+        query = ops.transpose(query, (0, 2, 1, 3))  # (batch_size, num_heads, seq_len, head_dim)
+        key = ops.transpose(key, (0, 2, 1, 3))
+        value = ops.transpose(value, (0, 2, 1, 3))
+
+        # Handle grouped query attention (GQA):
+        # repeat key and value heads to match the number of query heads.
+        if self.num_key_value_heads < self.num_query_heads:
+            num_groups = self.num_query_heads // self.num_key_value_heads
+            key = ops.repeat(key, num_groups, axis=1)
+            value = ops.repeat(value, num_groups, axis=1)
+
+        # Compute scaled attention scores. Cast `head_dim` to the score dtype
+        # so the square root is computed in floating point.
+        attention_scores = ops.matmul(query, ops.transpose(key, (0, 1, 3, 2)))
+        attention_scores = attention_scores / ops.sqrt(
+            ops.cast(self.head_dim, attention_scores.dtype)
+        )
+
+        # Apply attention mask
+        if attention_mask is not None:
+            if len(attention_mask.shape) == 2:
+                # Convert 2D mask to 4D for broadcasting
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+            attention_scores = ops.where(
+                attention_mask, attention_scores, ops.full_like(attention_scores, -1e9)
+            )
+
+        # Apply softmax
+        attention_weights = ops.softmax(attention_scores, axis=-1)
+
+        # Apply attention to values
+        attention_output = ops.matmul(attention_weights, value)
+
+        # Transpose back
+        attention_output = ops.transpose(attention_output, (0, 2, 1, 3))
+
+        # Reshape and project
+        attention_output = ops.reshape(
+            attention_output, (batch_size, seq_len, self.num_query_heads * self.head_dim)
+        )
+        attention_output = self.output_projection(attention_output)
+
+        return {
+            "hidden_states": attention_output,
+            "cache": new_cache,
+        }
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "hidden_dim": self.hidden_dim,
+                "head_dim": self.head_dim,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "dropout": self.dropout,
+                "sliding_window_size": self.sliding_window_size,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config
diff --git a/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_backbone.py b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_backbone.py
new file mode 100644
index 0000000000..f795cd39ce
--- /dev/null
+++ b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_backbone.py
@@ -0,0 +1,252 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_decoder import (
+    Qwen3OmniMoeTransformerDecoder,
+)
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_layernorm import Qwen3OmniMoeLayerNorm
+
+
+def _qwen3_omni_moe_kernel_initializer(stddev=0.02):
+    return keras.initializers.RandomNormal(stddev=stddev)
+
+
+@keras_hub_export("keras_hub.models.Qwen3OmniMoeBackbone")
+class Qwen3OmniMoeBackbone(Backbone):
+    """Qwen3-Omni MoE core network with multimodal capabilities.
+
+    This backbone implements the base Transformer network for the Qwen3-Omni MoE
+    model. It includes embedding lookups and transformer layers with a Mixture
+    of Experts (MoE) architecture, supporting text, audio, and vision inputs.
+    This backbone outputs the final hidden states for each token, not generative
+    predictions over the vocabulary space. For a higher-level object for text
+    generation, see `keras_hub.models.Qwen3OmniMoeCausalLM`.
+ + The default constructor gives a fully customizable, randomly initialized + Qwen3-Omni MoE model with any number of layers, heads, and embedding dimensions. + To load preset architectures and weights, use the `from_preset` constructor. + + Args: + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers. + num_query_heads: int. The number of heads for the query projections in + the attention layer. + num_key_value_heads: int. The number of heads for the key and value + projections in the attention layer. + hidden_dim: int. The size of the transformer hidden state at the end of + each transformer layer. + intermediate_dim: int. The output dimension of the first Dense layer in + the feedforward network for each transformer. + num_experts: int. The number of experts in each MoE layer. + num_experts_per_tok: int. The number of experts to select for each token + in the MoE layer. + head_dim: int. The size of each attention head. + layer_norm_epsilon: float. The epsilon value used for every layer norm + in the transformer model. + dropout: float. Dropout probability for the transformer encoder. + sliding_window_size: int. Size of the sliding local window. Defaults to + 4096. + max_sequence_length: int. The maximum sequence length supported by the + model. Defaults to 4096. + dtype: str or `keras.mixed_precision.DTypePolicy`. The dtype to use for + the model's computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. + + Example: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Qwen3-Omni MoE decoder. + model = keras_hub.models.Qwen3OmniMoeBackbone.from_preset("qwen3_omni_moe_7b") + model(input_data) + + # Randomly initialized Qwen3-Omni MoE decoder with custom config. 
+ model = keras_hub.models.Qwen3OmniMoeBackbone( + vocabulary_size=151936, + num_layers=32, + num_query_heads=32, + num_key_value_heads=4, + hidden_dim=4096, + intermediate_dim=11008, + num_experts=8, + num_experts_per_tok=2, + head_dim=128, + max_sequence_length=32768, + ) + model(input_data) + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_query_heads, + num_key_value_heads, + hidden_dim, + intermediate_dim, + num_experts, + num_experts_per_tok, + head_dim=None, + layer_norm_epsilon=1e-6, + dropout=0.0, + sliding_window_size=4096, + max_sequence_length=32768, + dtype=None, + **kwargs, + ): + # Set up the config + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.head_dim = head_dim if head_dim is not None else hidden_dim // num_query_heads + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.sliding_window_size = sliding_window_size + self.max_sequence_length = max_sequence_length + + # Token embeddings + self.token_embedding = ReversibleEmbedding( + vocabulary_size, + hidden_dim, + embeddings_initializer=_qwen3_omni_moe_kernel_initializer(), + dtype=dtype, + name="token_embedding", + ) + + # Transformer decoder + self.transformer_decoder = Qwen3OmniMoeTransformerDecoder( + num_layers=num_layers, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + head_dim=head_dim, + layer_norm_epsilon=layer_norm_epsilon, + dropout=dropout, + sliding_window_size=sliding_window_size, + max_sequence_length=max_sequence_length, + dtype=dtype, + name="transformer_decoder", + ) + + # Final layer norm + self.layer_norm = Qwen3OmniMoeLayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="layer_norm", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + + # Compute attention mask + attention_mask = ops.cast(padding_mask_input, dtype="bool") + + # Transformer decoder + decoder_outputs = self.transformer_decoder( + hidden_states=x, + attention_mask=attention_mask, + position_ids=None, + cache=None, + cache_update_index=None, + training=None, + ) + + sequence_output = self.layer_norm(decoder_outputs["hidden_states"]) + + super().__init__( + inputs=[token_id_input, padding_mask_input], + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + def call( + self, + inputs, + position_ids=None, + cache=None, + cache_update_index=None, + training=None, + ): + # Handle both dictionary and list inputs (for functional model compatibility) + if isinstance(inputs, dict): + token_ids = inputs["token_ids"] + padding_mask = inputs.get("padding_mask") + else: + # inputs is a list from functional model: [token_ids, padding_mask] + token_ids = inputs[0] + padding_mask = inputs[1] + + # Embed tokens + hidden_states = self.token_embedding(token_ids) + + # Compute attention mask + attention_mask = padding_mask + if attention_mask is not None: + attention_mask = ops.cast(attention_mask, dtype="bool") + else: + attention_mask = ops.ones( + 
(ops.shape(token_ids)[0], ops.shape(token_ids)[1]), + dtype="bool", + ) + + # Transformer decoder + decoder_outputs = self.transformer_decoder( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + cache=cache, + cache_update_index=cache_update_index, + training=training, + ) + + # Final layer norm + hidden_states = self.layer_norm(decoder_outputs["hidden_states"]) + + if cache_update_index is not None: + return hidden_states, decoder_outputs["cache"] + return hidden_states + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_experts": self.num_experts, + "num_experts_per_tok": self.num_experts_per_tok, + "head_dim": self.head_dim, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "sliding_window_size": self.sliding_window_size, + "max_sequence_length": self.max_sequence_length, + } + ) + return config diff --git a/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_backbone_test.py b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_backbone_test.py new file mode 100644 index 0000000000..3b48d5cae6 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_backbone_test.py @@ -0,0 +1,85 @@ +import pytest +from keras import ops + +from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_backbone import Qwen3OmniMoeBackbone +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3OmniMoeBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 151936, + "num_layers": 2, + "num_query_heads": 4, + "num_key_value_heads": 2, + "hidden_dim": 128, + "intermediate_dim": 256, + "num_experts": 4, + "num_experts_per_tok": 2, + "head_dim": 32, + "max_sequence_length": 128, + } + self.input_data = { + "token_ids": ops.ones((2, 16), dtype="int32"), + "padding_mask": ops.ones((2, 16), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=Qwen3OmniMoeBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 16, 128), + run_mixed_precision_check=False, # Disable mixed precision check due to MoE complexity + run_quantization_check=False, # Disable quantization check due to MoE complexity + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=Qwen3OmniMoeBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in Qwen3OmniMoeBackbone.presets: + self.run_preset_test( + cls=Qwen3OmniMoeBackbone, + preset=preset, + input_data=self.input_data, + ) + + def test_cache_functionality(self): + """Test that cache is properly handled and returned.""" + model = Qwen3OmniMoeBackbone(**self.init_kwargs) + + # First forward pass without cache + outputs1 = model(self.input_data) + self.assertEqual(outputs1.shape, (2, 16, 128)) + + # Second forward pass with cache + cache_input = { + "token_ids": ops.ones((2, 1), dtype="int32"), + "padding_mask": ops.ones((2, 1), dtype="int32"), + } + + outputs2, cache = model(cache_input, cache=None, cache_update_index=0) + self.assertEqual(outputs2.shape, (2, 1, 128)) + self.assertIsNotNone(cache) + + # Third forward pass using cache + outputs3, updated_cache = model(cache_input, cache=cache, 
cache_update_index=1) + self.assertEqual(outputs3.shape, (2, 1, 128)) + self.assertIsNotNone(updated_cache) + + def test_auxiliary_loss(self): + """Test that auxiliary losses are properly computed during training.""" + model = Qwen3OmniMoeBackbone(**self.init_kwargs) + _ = model(self.input_data, training=True) + self.assertTrue( + len(model.losses) > 0, "Auxiliary losses should be present" + ) + for loss in model.losses: + self.assertGreater(loss, 0.0, "Auxiliary loss should be positive") \ No newline at end of file diff --git a/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_causal_lm.py b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_causal_lm.py new file mode 100644 index 0000000000..ad756c575b --- /dev/null +++ b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_causal_lm.py @@ -0,0 +1,265 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_backbone import Qwen3OmniMoeBackbone +from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_causal_lm_preprocessor import ( + Qwen3OmniMoeCausalLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export( + "keras_hub.models.Qwen3OmniMoeCausalLM", +) +class Qwen3OmniMoeCausalLM(CausalLM): + """An end-to-end Qwen3-Omni MoE model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on plain + text input, or to autoregressively generate plain text similar to the data + used for training. This task can be used for pre-training or fine-tuning a + Qwen3-Omni MoE model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_hub.samplers` objects to control the generation. + By default, `"greedy"` sampling will be used. + + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to string inputs during + `fit()`, `predict()`, `evaluate()`, and `generate()`. This is done by + default when creating the model with `from_preset()`. + + The Qwen3-Omni MoE architecture leverages a Mixture of Experts (MoE) design + with multimodal capabilities, supporting text, audio, and vision inputs. + Each transformer layer uses a sparse set of experts to process tokens + efficiently, making it suitable for large-scale multimodal AI tasks. + + Args: + backbone: A `keras_hub.models.Qwen3OmniMoeBackbone` instance. + preprocessor: A `keras_hub.models.Qwen3OmniMoeCausalLMPreprocessor` or + `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. + + Examples: + + Use `generate()` to do text generation. + ```python + qwen3_omni_moe_lm = keras_hub.models.Qwen3OmniMoeCausalLM.from_preset( + "qwen3_omni_moe_7b" + ) + qwen3_omni_moe_lm.generate("I want to say", max_length=30) + + # Generate with batched prompts. + qwen3_omni_moe_lm.generate(["This is a", "Where are you"], max_length=30) + ``` + + Compile the `generate()` function with a custom sampler. 
+ ```python + qwen3_omni_moe_lm = keras_hub.models.Qwen3OmniMoeCausalLM.from_preset( + "qwen3_omni_moe_7b" + ) + qwen3_omni_moe_lm.compile(sampler="top_k") + qwen3_omni_moe_lm.generate("I want to say", max_length=30) + + qwen3_omni_moe_lm.compile(sampler=keras_hub.samplers.BeamSampler(num_beams=2)) + qwen3_omni_moe_lm.generate("I want to say", max_length=30) + ``` + + Use `generate()` without preprocessing. + ```python + prompt = { + # Token ids for " Qwen3-Omni is". + "token_ids": np.array([[2, 12345, 678, 0, 0, 0, 0]] * 2), + # Use `"padding_mask"` to indicate values that should not be overridden. + "padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0]] * 2), + } + + qwen3_omni_moe_lm = keras_hub.models.Qwen3OmniMoeCausalLM.from_preset( + "qwen3_omni_moe_7b", + preprocessor=None, + ) + qwen3_omni_moe_lm.generate(prompt) + ``` + + Call `fit()` on a single batch. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + qwen3_omni_moe_lm = keras_hub.models.Qwen3OmniMoeCausalLM.from_preset( + "qwen3_omni_moe_7b" + ) + qwen3_omni_moe_lm.fit(x=features, batch_size=2) + ``` + + Call `fit()` with LoRA fine-tuning enabled. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + qwen3_omni_moe_lm = keras_hub.models.Qwen3OmniMoeCausalLM.from_preset( + "qwen3_omni_moe_7b" + ) + qwen3_omni_moe_lm.backbone.enable_lora(rank=4) + qwen3_omni_moe_lm.fit(x=features, batch_size=2) + ``` + + Call `fit()` without preprocessing. + ```python + x = { + # Token ids for " Qwen3-Omni is a multimodal model" + "token_ids": np.array([[2, 12345, 678, 543, 9876, 1, 0, 0]] * 2), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 0, 0]] * 2), + } + y = np.array([[12345, 678, 543, 9876, 1, 0, 0, 0]] * 2) + sw = np.array([[1, 1, 1, 1, 1, 0, 0, 0]] * 2) + + qwen3_omni_moe_lm = keras_hub.models.Qwen3OmniMoeCausalLM.from_preset( + "qwen3_omni_moe_7b", + preprocessor=None, + ) + qwen3_omni_moe_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2) + ``` + + Custom backbone and vocabulary. + ```python + tokenizer = keras_hub.models.Qwen3OmniMoeTokenizer( + proto="qwen3_omni_moe_vocab.spm", + ) + preprocessor = keras_hub.models.Qwen3OmniMoeCausalLMPreprocessor( + tokenizer=tokenizer, + sequence_length=128, + ) + backbone = keras_hub.models.Qwen3OmniMoeBackbone( + vocabulary_size=151936, + num_layers=32, + num_query_heads=32, + num_key_value_heads=4, + hidden_dim=4096, + intermediate_dim=11008, + num_experts=8, + num_experts_per_tok=2, + head_dim=128, + max_sequence_length=32768, + ) + qwen3_omni_moe_lm = keras_hub.models.Qwen3OmniMoeCausalLM( + backbone=backbone, + preprocessor=preprocessor, + ) + qwen3_omni_moe_lm.fit(x=features, batch_size=2) + ``` + """ + + backbone_cls = Qwen3OmniMoeBackbone + preprocessor_cls = Qwen3OmniMoeCausalLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. + inputs = backbone.input + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass of `Qwen3OmniMoeCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. 
Unlike calling the model directly, this method
+        allows caching previous key/value Tensors in the multi-head attention
+        layers, and avoids recomputing the outputs of seen tokens.
+
+        Args:
+            token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
+            cache: a dense float Tensor, the cache of key and value.
+            cache_update_index: int, or int Tensor. The index of current inputs
+                in the whole sequence.
+
+        Returns:
+            A `(logits, hidden_states, cache)` tuple, where `logits` is the
+            language model logits for the input token_ids, `hidden_states` is
+            the final hidden representation of the input tokens, and `cache` is
+            the decoding cache.
+        """
+        x = self.backbone.token_embedding(token_ids)
+        # Each decoder layer has a cache; we update them separately.
+        updated_cache = []
+        for i in range(self.backbone.num_layers):
+            layer = self.backbone.transformer_decoder.layers[i]
+            # The decoder layer returns a dict with the new hidden states and
+            # the updated attention cache for this layer.
+            layer_outputs = layer(
+                x,
+                cache=cache[i],
+                cache_update_index=cache_update_index,
+            )
+            x = layer_outputs["hidden_states"]
+            updated_cache.append(layer_outputs["cache"])
+        x = self.backbone.layer_norm(x)
+        logits = self.backbone.token_embedding(x, reverse=True)
+        return logits, x, updated_cache
+
+    def compute_loss(
+        self,
+        x=None,
+        y=None,
+        y_pred=None,
+        sample_weight=None,
+    ):
+        """Compute the loss of the model.
+
+        Args:
+            x: Input data.
+            y: Target data.
+            y_pred: Predictions returned by the model.
+            sample_weight: Sample weights for the loss computation.
+
+        Returns:
+            The loss of the model.
+        """
+        if y is None:
+            return None
+
+        # Use the provided predictions if available, otherwise run a forward
+        # pass through the model.
+        if y_pred is None:
+            y_pred = self(x)
+
+        # Compute cross-entropy loss
+        y_true = ops.cast(y, dtype="int32")
+        y_pred = ops.cast(y_pred, dtype="float32")
+
+        # Flatten for loss computation
+        y_true_flat = ops.reshape(y_true, (-1,))
+        y_pred_flat = ops.reshape(y_pred, (-1, ops.shape(y_pred)[-1]))
+
+        # Compute cross-entropy loss
+        loss = keras.losses.sparse_categorical_crossentropy(
+            y_true_flat, y_pred_flat, from_logits=True
+        )
+
+        # Apply sample weights if provided
+        if sample_weight is not None:
+            sample_weight_flat = ops.reshape(sample_weight, (-1,))
+            loss = loss * sample_weight_flat
+
+        # Add auxiliary losses from MoE layers
+        total_loss = ops.mean(loss)
+        for auxiliary_loss in self.losses:
+            total_loss += auxiliary_loss
+
+        return total_loss
diff --git a/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_causal_lm_preprocessor.py b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_causal_lm_preprocessor.py
new file mode 100644
index 0000000000..a0d6317a93
--- /dev/null
+++ b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_causal_lm_preprocessor.py
@@ -0,0 +1,33 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_backbone import Qwen3OmniMoeBackbone
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_tokenizer import Qwen3OmniMoeTokenizer
+
+
+@keras_hub_export(
+    "keras_hub.models.Qwen3OmniMoeCausalLMPreprocessor",
+)
+class Qwen3OmniMoeCausalLMPreprocessor(CausalLMPreprocessor):
+    """Preprocessor for Qwen3-Omni MoE causal language model.
+
+    This preprocessor handles tokenization and preprocessing for the Qwen3-Omni
+    MoE model, supporting multimodal inputs including text, audio, and vision.
+
+    Args:
+        tokenizer: A `Qwen3OmniMoeTokenizer` instance.
+        sequence_length: int. The length of the packed sequence.
+ add_start_token: bool. Whether to add the start token. Defaults to True. + add_end_token: bool. Whether to add the end token. Defaults to True. + + Example: + ```python + # Create preprocessor + preprocessor = Qwen3OmniMoeCausalLMPreprocessor.from_preset("qwen3_omni_moe_7b") + + # Preprocess text + preprocessed = preprocessor(["Hello, world!", "How are you?"]) + ``` + """ + + backbone_cls = Qwen3OmniMoeBackbone + tokenizer_cls = Qwen3OmniMoeTokenizer diff --git a/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_causal_lm_preprocessor_test.py b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..14171d24d4 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_causal_lm_preprocessor_test.py @@ -0,0 +1,218 @@ +import pytest +from keras import ops + +from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_causal_lm_preprocessor import Qwen3OmniMoeCausalLMPreprocessor +from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_tokenizer import Qwen3OmniMoeTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3OmniMoeCausalLMPreprocessorTest(TestCase): + def setUp(self): + # Create a dummy tokenizer for testing + self.vocabulary = { + "<|endoftext|>": 0, + "<|im_end|>": 1, + "hello": 2, + "world": 3, + "how": 4, + "are": 5, + "you": 6, + "the": 7, + "quick": 8, + "brown": 9, + "fox": 10, + "jumps": 11, + "over": 12, + "lazy": 13, + "dog": 14, + } + self.merges = ["h e", "l l", "o ", "w o", "r l", "d "] + + self.tokenizer = Qwen3OmniMoeTokenizer( + vocabulary=self.vocabulary, + merges=self.merges, + ) + + # Create preprocessor + self.preprocessor = Qwen3OmniMoeCausalLMPreprocessor( + tokenizer=self.tokenizer, + sequence_length=128, + ) + + # Test data + self.test_texts = [ + "Hello, world!", + "How are you today?", + "The quick brown fox jumps over the lazy dog." + ] + + def test_preprocessor_basics(self): + """Test basic preprocessor functionality.""" + # Test preprocessing single text + text = "Hello, world!" 
+ preprocessed = self.preprocessor(text) + + # Should return token_ids and padding_mask + self.assertIn("token_ids", preprocessed) + self.assertIn("padding_mask", preprocessed) + + # Check shapes + self.assertEqual(len(preprocessed["token_ids"].shape), 2) # (batch_size, seq_len) + self.assertEqual(len(preprocessed["padding_mask"].shape), 2) # (batch_size, seq_len) + + def test_preprocessor_batch_processing(self): + """Test batch processing of multiple texts.""" + preprocessed = self.preprocessor(self.test_texts) + + # Should handle multiple texts + batch_size = preprocessed["token_ids"].shape[0] + seq_len = preprocessed["token_ids"].shape[1] + + self.assertEqual(batch_size, len(self.test_texts)) + self.assertEqual(seq_len, 128) # sequence_length + + def test_preprocessor_sequence_length(self): + """Test that preprocessor respects sequence_length parameter.""" + # Test with different sequence lengths + for seq_len in [32, 64, 128]: + preprocessor = Qwen3OmniMoeCausalLMPreprocessor( + tokenizer=self.tokenizer, + sequence_length=seq_len, + ) + + preprocessed = preprocessor("Hello, world!") + actual_seq_len = preprocessed["token_ids"].shape[1] + self.assertEqual(actual_seq_len, seq_len) + + def test_preprocessor_padding(self): + """Test padding behavior.""" + # Test with short text + short_text = "Hello" + preprocessed = self.preprocessor(short_text) + + # Should pad to sequence_length + seq_len = preprocessed["token_ids"].shape[1] + self.assertEqual(seq_len, 128) + + # Padding mask should indicate valid tokens + padding_mask = preprocessed["padding_mask"] + self.assertIn(0, padding_mask.numpy().flatten()) # Should have padding tokens + + def test_preprocessor_truncation(self): + """Test truncation behavior.""" + # Create very long text + long_text = " ".join(["word"] * 200) # Very long text + + preprocessed = self.preprocessor(long_text) + + # Should truncate to sequence_length + seq_len = preprocessed["token_ids"].shape[1] + self.assertEqual(seq_len, 128) + + def test_preprocessor_special_tokens(self): + """Test that special tokens are properly handled.""" + # Test with text that might trigger special token handling + text = "Hello <|im_end|> world" + preprocessed = self.preprocessor(text) + + # Should handle special tokens correctly + self.assertIn("token_ids", preprocessed) + self.assertIn("padding_mask", preprocessed) + + def test_preprocessor_empty_input(self): + """Test preprocessor with empty input.""" + preprocessed = self.preprocessor("") + + # Should handle empty input gracefully + self.assertIn("token_ids", preprocessed) + self.assertIn("padding_mask", preprocessed) + + # Should still have correct shape + seq_len = preprocessed["token_ids"].shape[1] + self.assertEqual(seq_len, 128) + + def test_preprocessor_different_input_types(self): + """Test preprocessor with different input types.""" + # Test with string + preprocessed_str = self.preprocessor("Hello, world!") + + # Test with list of strings + preprocessed_list = self.preprocessor(["Hello", "world"]) + + # Both should work + self.assertIn("token_ids", preprocessed_str) + self.assertIn("token_ids", preprocessed_list) + + def test_preprocessor_from_preset(self): + """Test loading preprocessor from preset.""" + try: + preprocessor = Qwen3OmniMoeCausalLMPreprocessor.from_preset("qwen3_omni_moe_7b") + + # If successful, test basic functionality + preprocessed = preprocessor("test") + self.assertIn("token_ids", preprocessed) + except (ValueError, ImportError): + # Skip if preset not available + pytest.skip("Preset not available") + + 
def test_preprocessor_consistency(self): + """Test that preprocessor gives consistent results.""" + text = "Hello, world!" + + # Preprocess same text multiple times + results = [] + for _ in range(3): + preprocessed = self.preprocessor(text) + results.append(preprocessed) + + # Results should be identical + for i in range(1, len(results)): + self.assertTrue( + ops.allclose(results[0]["token_ids"], results[i]["token_ids"]) + ) + self.assertTrue( + ops.allclose(results[0]["padding_mask"], results[i]["padding_mask"]) + ) + + def test_preprocessor_tokenizer_integration(self): + """Test integration with tokenizer.""" + # Test that preprocessor uses tokenizer correctly + text = "Hello, world!" + preprocessed = self.preprocessor(text) + + # Token IDs should be valid vocabulary indices + token_ids = preprocessed["token_ids"] + vocab_size = len(self.tokenizer.get_vocabulary()) + + # All token IDs should be within vocabulary range + self.assertTrue(ops.all(token_ids >= 0)) + self.assertTrue(ops.all(token_ids < vocab_size)) + + def test_preprocessor_output_format(self): + """Test that preprocessor output format matches model expectations.""" + preprocessed = self.preprocessor("Hello, world!") + + # Should return dictionary with required keys + required_keys = ["token_ids", "padding_mask"] + for key in required_keys: + self.assertIn(key, preprocessed) + + # Should have correct data types + self.assertEqual(preprocessed["token_ids"].dtype.name, "int32") + self.assertEqual(preprocessed["padding_mask"].dtype.name, "int32") + + def test_preprocessor_with_long_sequences(self): + """Test preprocessor with sequences at the boundary.""" + # Test with sequence exactly at sequence_length + text = " ".join(["word"] * 100) # Long text + + preprocessed = self.preprocessor(text) + + # Should handle long sequences correctly + seq_len = preprocessed["token_ids"].shape[1] + self.assertEqual(seq_len, 128) + + # Should have valid padding mask + padding_mask = preprocessed["padding_mask"] + self.assertIn(0, padding_mask.numpy().flatten()) # Should have padding + self.assertIn(1, padding_mask.numpy().flatten()) # Should have valid tokens diff --git a/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_causal_lm_test.py b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_causal_lm_test.py new file mode 100644 index 0000000000..1edd709205 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_causal_lm_test.py @@ -0,0 +1,200 @@ +import pytest +from keras import ops + +from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_backbone import Qwen3OmniMoeBackbone +from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_causal_lm import Qwen3OmniMoeCausalLM +from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_causal_lm_preprocessor import Qwen3OmniMoeCausalLMPreprocessor +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3OmniMoeCausalLMTest(TestCase): + def setUp(self): + # Create backbone for testing + self.backbone = Qwen3OmniMoeBackbone( + vocabulary_size=151936, + num_layers=2, + num_query_heads=8, + num_key_value_heads=2, + hidden_dim=256, + intermediate_dim=512, + num_experts=4, + num_experts_per_tok=2, + head_dim=32, + max_sequence_length=512, + ) + + # Create CausalLM model + self.model = Qwen3OmniMoeCausalLM(backbone=self.backbone) + + # Test input data + self.input_data = { + "token_ids": ops.ones((2, 10), dtype="int32"), + "padding_mask": ops.ones((2, 10), dtype="int32"), + } + + def test_causal_lm_basics(self): + """Test basic CausalLM functionality.""" + # Test forward pass + 
outputs = self.model(self.input_data) + + # Should return logits with correct shape + expected_shape = (2, 10, 151936) # (batch_size, seq_len, vocab_size) + self.assertEqual(outputs.shape, expected_shape) + + def test_causal_lm_generation(self): + """Test text generation functionality.""" + # Test generate method + prompt = "Hello, how are you" + + try: + generated = self.model.generate( + prompt, + max_length=20, + from_logits=True + ) + + # Should return generated text + self.assertIsInstance(generated, str) + self.assertGreater(len(generated), len(prompt)) + except Exception as e: + # Skip if generation fails (expected for untrained model) + pytest.skip(f"Generation test skipped: {e}") + + def test_causal_lm_with_preprocessor(self): + """Test CausalLM with preprocessor.""" + # Create preprocessor (without tokenizer for testing) + try: + preprocessor = Qwen3OmniMoeCausalLMPreprocessor( + tokenizer=None, + sequence_length=128, + ) + + # Create model with preprocessor + model_with_preprocessor = Qwen3OmniMoeCausalLM( + backbone=self.backbone, + preprocessor=preprocessor + ) + + # Test that model can be created + self.assertIsNotNone(model_with_preprocessor) + except Exception as e: + # Skip if preprocessor creation fails + pytest.skip(f"Preprocessor test skipped: {e}") + + def test_causal_lm_training(self): + """Test CausalLM training setup.""" + # Test loss computation + y_true = ops.ones((2, 10), dtype="int32") + + # Compute loss + loss = self.model.compute_loss( + x=self.input_data, + y=y_true, + ) + + # Should return a scalar loss + self.assertIsNotNone(loss) + self.assertEqual(len(loss.shape), 0) # Scalar + + def test_causal_lm_cache_functionality(self): + """Test cache functionality for generation.""" + # Test call_with_cache method + token_ids = ops.ones((2, 1), dtype="int32") + cache = [None] * self.backbone.num_layers + cache_update_index = 0 + + try: + logits, hidden_states, updated_cache = self.model.call_with_cache( + token_ids=token_ids, + cache=cache, + cache_update_index=cache_update_index + ) + + # Should return logits, hidden states, and updated cache + self.assertEqual(logits.shape, (2, 1, 151936)) + self.assertEqual(hidden_states.shape, (2, 1, 256)) + self.assertEqual(len(updated_cache), self.backbone.num_layers) + except Exception as e: + # Skip if cache functionality fails + pytest.skip(f"Cache test skipped: {e}") + + def test_causal_lm_from_preset(self): + """Test loading CausalLM from preset.""" + try: + model = Qwen3OmniMoeCausalLM.from_preset("qwen3_omni_moe_7b") + + # If successful, test basic functionality + test_input = { + "token_ids": ops.ones((1, 5), dtype="int32"), + "padding_mask": ops.ones((1, 5), dtype="int32"), + } + outputs = model(test_input) + self.assertIsNotNone(outputs) + except (ValueError, ImportError): + # Skip if preset not available + pytest.skip("Preset not available") + + def test_causal_lm_parameter_count(self): + """Test that model has reasonable parameter count.""" + param_count = self.model.count_params() + + # Should have reasonable number of parameters + self.assertGreater(param_count, 1000000) # At least 1M parameters + self.assertLess(param_count, 100000000) # Less than 100M for test model + + def test_causal_lm_auxiliary_losses(self): + """Test that auxiliary losses are properly handled.""" + # Forward pass with training=True + outputs = self.model(self.input_data, training=True) + + # Should have auxiliary losses from MoE layers + auxiliary_losses = self.model.losses + self.assertGreaterEqual(len(auxiliary_losses), 0) # May have 
MoE auxiliary losses
+
+    def test_causal_lm_save_load(self):
+        """Test model saving and loading."""
+        import tempfile
+        import os
+
+        import keras
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            save_path = os.path.join(temp_dir, "test_model")
+
+            try:
+                # Save model
+                self.model.save(save_path)
+
+                # Load model
+                loaded_model = keras.models.load_model(save_path)
+
+                # Test that loaded model works
+                outputs = loaded_model(self.input_data)
+                self.assertEqual(outputs.shape, (2, 10, 151936))
+            except Exception as e:
+                # Skip if save/load fails
+                pytest.skip(f"Save/load test skipped: {e}")
+
+    def test_causal_lm_different_input_sizes(self):
+        """Test model with different input sizes."""
+        # Test with different sequence lengths
+        for seq_len in [5, 10, 20]:
+            input_data = {
+                "token_ids": ops.ones((2, seq_len), dtype="int32"),
+                "padding_mask": ops.ones((2, seq_len), dtype="int32"),
+            }
+
+            outputs = self.model(input_data)
+            expected_shape = (2, seq_len, 151936)
+            self.assertEqual(outputs.shape, expected_shape)
+
+    def test_causal_lm_batch_processing(self):
+        """Test model with different batch sizes."""
+        # Test with different batch sizes
+        for batch_size in [1, 2, 4]:
+            input_data = {
+                "token_ids": ops.ones((batch_size, 10), dtype="int32"),
+                "padding_mask": ops.ones((batch_size, 10), dtype="int32"),
+            }
+
+            outputs = self.model(input_data)
+            expected_shape = (batch_size, 10, 151936)
+            self.assertEqual(outputs.shape, expected_shape)
diff --git a/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_decoder.py b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_decoder.py
new file mode 100644
index 0000000000..dc1c496afd
--- /dev/null
+++ b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_decoder.py
@@ -0,0 +1,538 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    merge_padding_and_attention_mask,
+)
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_attention import Qwen3OmniMoeAttention
+from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_layernorm import Qwen3OmniMoeLayerNorm
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+def compute_load_balancing_loss(
+    router_logits, num_experts, num_experts_per_tok, attention_mask=None
+):
+    """
+    Compute the load balancing auxiliary loss for a single MoE layer.
+
+    Args:
+        router_logits: Tensor of shape (batch_size * seq_len, num_experts).
+        num_experts: Integer, total number of experts.
+        num_experts_per_tok: Integer, number of experts to select per token.
+        attention_mask: Tensor of shape (batch_size, seq_len, seq_len),
+            optional mask for padding.
+
+    Returns:
+        Scalar tensor representing the auxiliary loss.
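+
+    Example:
+
+    A minimal sketch of calling this helper directly, with made-up shapes
+    (in the model, the router logits come from the MoE block's router):
+
+    ```python
+    import numpy as np
+    from keras import ops
+
+    num_tokens, num_experts = 16, 4  # illustrative values
+    router_logits = ops.convert_to_tensor(
+        np.random.randn(num_tokens, num_experts).astype("float32")
+    )
+    aux_loss = compute_load_balancing_loss(
+        router_logits, num_experts=num_experts, num_experts_per_tok=2
+    )
+    # Scalar tensor; grows as the routing becomes more unbalanced.
+    ```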
+    """
+    # Compute routing probabilities
+    routing_weights = ops.softmax(
+        router_logits, axis=-1
+    )  # Shape: (batch_size * seq_len, num_experts)
+
+    # Get top-k experts
+    _, selected_experts = ops.top_k(
+        routing_weights, k=num_experts_per_tok
+    )  # Shape: (batch_size * seq_len, num_experts_per_tok)
+
+    # Create one-hot encoding for selected experts
+    expert_mask = ops.one_hot(
+        selected_experts, num_experts
+    )  # Shape: (batch_size * seq_len, num_experts_per_tok, num_experts)
+
+    if attention_mask is not None:
+        # Zero out padded positions before averaging.
+        flat_mask = ops.any(attention_mask, axis=-1)
+        flat_mask = ops.reshape(flat_mask, (-1, 1, 1))  # (batch_size * seq_len, 1, 1)
+        expert_mask = expert_mask * ops.cast(flat_mask, expert_mask.dtype)
+
+    # Fraction of routing assignments received by each expert.
+    tokens_per_expert = ops.mean(
+        expert_mask, axis=0
+    )  # Shape: (num_experts_per_tok, num_experts)
+
+    # Mean router probability assigned to each expert.
+    router_prob_per_expert = ops.mean(
+        routing_weights, axis=0
+    )  # Shape: (num_experts,)
+
+    # Switch-Transformer-style load balancing loss: the dot product of the
+    # dispatch fractions and the router probabilities, scaled by the number
+    # of experts. It is minimized when tokens are spread uniformly.
+    load_balancing_loss = num_experts * ops.sum(
+        tokens_per_expert * ops.expand_dims(router_prob_per_expert, axis=0)
+    )
+
+    return load_balancing_loss
+
+
+class Qwen3OmniMoeTransformerDecoderLayer(keras.layers.Layer):
+    """A transformer decoder layer for the Qwen3-Omni MoE model.
+
+    This layer implements a complete transformer decoder layer with
+    self-attention and a sparse mixture-of-experts (MoE) feedforward network.
+    It uses a pre-normalization architecture with RMSNorm for improved
+    training stability.
+
+    Args:
+        num_query_heads: int. The number of heads for the query projections.
+        num_key_value_heads: int. The number of heads for the key and value
+            projections (must be <= num_query_heads).
+        hidden_dim: int. The size of the transformer hidden state.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            the feedforward network.
+        num_experts: int. The number of experts in each MoE layer.
+        num_experts_per_tok: int. The number of experts to select for each token.
+        head_dim: int. The size of each attention head.
+        layer_norm_epsilon: float, default 1e-6. The epsilon value used for
+            layer normalization.
+        dropout: float, default 0.0. Dropout probability.
+        sliding_window_size: int, default 4096. Size of the sliding local window.
+        max_sequence_length: int, default 32768. The maximum sequence length
+            supported by the model.
+        dtype: str or `keras.mixed_precision.DTypePolicy`, optional. The dtype
+            to use for the layer's computations and weights.
+ + Example: + ```python + # Create decoder layer + layer = Qwen3OmniMoeTransformerDecoderLayer( + num_query_heads=32, + num_key_value_heads=4, + hidden_dim=4096, + intermediate_dim=11008, + num_experts=8, + num_experts_per_tok=2 + ) + + # Apply to input + hidden_states = keras.random.normal((2, 10, 4096)) + outputs = layer(hidden_states) + # outputs["hidden_states"] shape: (2, 10, 4096) + # outputs["cache"] contains attention cache + # outputs["router_logits"] contains MoE routing logits + ``` + """ + + def __init__( + self, + num_query_heads, + num_key_value_heads, + hidden_dim, + intermediate_dim, + num_experts, + num_experts_per_tok, + head_dim, + layer_norm_epsilon=1e-6, + dropout=0.0, + sliding_window_size=4096, + max_sequence_length=32768, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.head_dim = head_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.sliding_window_size = sliding_window_size + self.max_sequence_length = max_sequence_length + + # Self-attention + self.attention = Qwen3OmniMoeAttention( + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + hidden_dim=hidden_dim, + head_dim=head_dim, + layer_norm_epsilon=layer_norm_epsilon, + dropout=dropout, + sliding_window_size=sliding_window_size, + max_sequence_length=max_sequence_length, + dtype=dtype, + name="attention", + ) + + # MoE feedforward + self.moe_feedforward = Qwen3OmniMoeSparseMoeBlock( + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + dtype=dtype, + name="moe_feedforward", + ) + + # Layer norms + self.attention_layer_norm = Qwen3OmniMoeLayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="attention_layer_norm", + ) + self.feedforward_layer_norm = Qwen3OmniMoeLayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="feedforward_layer_norm", + ) + + def call( + self, + hidden_states, + attention_mask=None, + position_ids=None, + cache=None, + cache_update_index=None, + training=None, + ): + residual = hidden_states + hidden_states = self.attention_layer_norm(hidden_states) + + # Self-attention + attention_outputs = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + cache=cache, + cache_update_index=cache_update_index, + training=training, + ) + attention_output = attention_outputs["hidden_states"] + attention_cache = attention_outputs.get("cache") + + # Residual connection + hidden_states = residual + attention_output + + residual = hidden_states + hidden_states = self.feedforward_layer_norm(hidden_states) + + # MoE feedforward + feedforward_outputs = self.moe_feedforward( + hidden_states=hidden_states, + training=training, + ) + feedforward_output = feedforward_outputs["hidden_states"] + + # Residual connection + hidden_states = residual + feedforward_output + + # Collect auxiliary loss + auxiliary_loss = feedforward_outputs.get("auxiliary_loss") + if auxiliary_loss is not None: + self.add_loss(auxiliary_loss) + + return { + "hidden_states": hidden_states, + "cache": attention_cache, + "router_logits": feedforward_outputs.get("router_logits"), + } + + def get_config(self): + config = super().get_config() + config.update( + { + 
"num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_experts": self.num_experts, + "num_experts_per_tok": self.num_experts_per_tok, + "head_dim": self.head_dim, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "sliding_window_size": self.sliding_window_size, + "max_sequence_length": self.max_sequence_length, + } + ) + return config + + +class Qwen3OmniMoeSparseMoeBlock(keras.layers.Layer): + """A sparse mixture-of-experts (MoE) block for Qwen3-Omni MoE model. + + This layer implements a sparse MoE feedforward network that routes tokens + to a subset of experts based on learned routing probabilities. It uses + top-k expert selection with weighted combination for efficient computation. + + Args: + hidden_dim: int. The size of the transformer hidden state. + intermediate_dim: int. The output dimension of the first Dense layer + in each expert. + num_experts: int. The number of experts in the MoE layer. + num_experts_per_tok: int. The number of experts to select for each token. + dtype: str or `keras.mixed_precision.DTypePolicy`, optional. The dtype + to use for the layer's computations and weights. + + Example: + ```python + # Create MoE block + moe_block = Qwen3OmniMoeSparseMoeBlock( + hidden_dim=4096, + intermediate_dim=11008, + num_experts=8, + num_experts_per_tok=2 + ) + + # Apply to input + hidden_states = keras.random.normal((2, 10, 4096)) + outputs = moe_block(hidden_states) + # outputs["hidden_states"] shape: (2, 10, 4096) + # outputs["router_logits"] shape: (20, 8) - routing logits for each token + ``` + """ + + def __init__( + self, + hidden_dim, + intermediate_dim, + num_experts, + num_experts_per_tok, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + + # Router + self.router = keras.layers.Dense( + num_experts, + use_bias=False, + dtype=dtype, + name="router", + ) + + # Experts + self.experts = [] + for i in range(num_experts): + expert = keras.Sequential([ + keras.layers.Dense( + intermediate_dim, + activation="silu", + dtype=dtype, + name=f"expert_{i}_up", + ), + keras.layers.Dense( + hidden_dim, + dtype=dtype, + name=f"expert_{i}_down", + ), + ], name=f"expert_{i}") + self.experts.append(expert) + + def call(self, hidden_states, training=None): + batch_size, seq_len, hidden_dim = ops.shape(hidden_states) + + # Flatten for routing + hidden_states_flat = ops.reshape(hidden_states, (-1, hidden_dim)) + + # Get router logits + router_logits = self.router(hidden_states_flat) + + # Get top-k experts and their weights + routing_weights = ops.softmax(router_logits, axis=-1) + gating_weights, selected_experts = ops.top_k(routing_weights, k=self.num_experts_per_tok) + gating_weights /= ops.sum(gating_weights, axis=-1, keepdims=True) + + # Create a mask for the selected experts + expert_mask = ops.one_hot(selected_experts, self.num_experts, dtype=gating_weights.dtype) + + # Compute expert outputs + expert_outputs_list = [] + for expert in self.experts: + expert_outputs_list.append(expert(hidden_states_flat)) + expert_outputs = ops.stack(expert_outputs_list, axis=1) + + # Weight expert outputs by gating probabilities + weighted_expert_output = ops.einsum("...SE,...ED->...SD", expert_mask, expert_outputs) + gating_weights = ops.expand_dims(gating_weights, axis=-1) 
+ final_output = ops.sum(weighted_expert_output * gating_weights, axis=1) + + # Reshape back + final_output = ops.reshape(final_output, (batch_size, seq_len, hidden_dim)) + + # Compute auxiliary loss for load balancing + auxiliary_loss = compute_load_balancing_loss( + router_logits, self.num_experts, self.num_experts_per_tok + ) + + return { + "hidden_states": final_output, + "router_logits": router_logits, + "auxiliary_loss": auxiliary_loss, + } + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_experts": self.num_experts, + "num_experts_per_tok": self.num_experts_per_tok, + } + ) + return config + + +@keras_hub_export("keras_hub.models.Qwen3OmniMoeTransformerDecoder") +class Qwen3OmniMoeTransformerDecoder(keras.layers.Layer): + """A transformer decoder for Qwen3-Omni MoE model. + + This layer implements a stack of transformer decoder layers with sparse + mixture-of-experts (MoE) feedforward networks. Each layer includes self-attention + and MoE feedforward with pre-normalization architecture. + + Args: + num_layers: int. The number of transformer decoder layers. + num_query_heads: int. The number of heads for the query projections. + num_key_value_heads: int. The number of heads for the key and value + projections (must be <= num_query_heads). + hidden_dim: int. The size of the transformer hidden state. + intermediate_dim: int. The output dimension of the first Dense layer in + the feedforward network. + num_experts: int. The number of experts in each MoE layer. + num_experts_per_tok: int. The number of experts to select for each token. + head_dim: int. The size of each attention head. + layer_norm_epsilon: float, default 1e-6. The epsilon value used for + layer normalization. + dropout: float, default 0.0. Dropout probability. + sliding_window_size: int, default 4096. Size of the sliding local window. + max_sequence_length: int, default 32768. The maximum sequence length + supported by the model. + dtype: str or `keras.mixed_precision.DTypePolicy`, optional. The dtype + to use for the layer's computations and weights. 
+ + Example: + ```python + # Create transformer decoder + decoder = Qwen3OmniMoeTransformerDecoder( + num_layers=32, + num_query_heads=32, + num_key_value_heads=4, + hidden_dim=4096, + intermediate_dim=11008, + num_experts=8, + num_experts_per_tok=2 + ) + + # Apply to input + hidden_states = keras.random.normal((2, 10, 4096)) + outputs = decoder(hidden_states) + # outputs["hidden_states"] shape: (2, 10, 4096) + # outputs["cache"] contains attention caches from all layers + # outputs["all_router_logits"] contains MoE routing logits from all layers + ``` + """ + + def __init__( + self, + num_layers, + num_query_heads, + num_key_value_heads, + hidden_dim, + intermediate_dim, + num_experts, + num_experts_per_tok, + head_dim, + layer_norm_epsilon=1e-6, + dropout=0.0, + sliding_window_size=4096, + max_sequence_length=32768, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.head_dim = head_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + self.sliding_window_size = sliding_window_size + self.max_sequence_length = max_sequence_length + + # Transformer layers + self.layers = [] + for i in range(num_layers): + layer = Qwen3OmniMoeTransformerDecoderLayer( + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + head_dim=head_dim, + layer_norm_epsilon=layer_norm_epsilon, + dropout=dropout, + sliding_window_size=sliding_window_size, + max_sequence_length=max_sequence_length, + dtype=dtype, + name=f"layer_{i}", + ) + self.layers.append(layer) + + def call( + self, + hidden_states, + attention_mask=None, + position_ids=None, + cache=None, + cache_update_index=None, + training=None, + ): + # Initialize cache if not provided + if cache is None: + cache = [None] * self.num_layers + + # Process through layers + all_hidden_states = [] + all_router_logits = [] + current_cache = [] + + for i, layer in enumerate(self.layers): + layer_outputs = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + cache=cache[i], + cache_update_index=cache_update_index, + training=training, + ) + + hidden_states = layer_outputs["hidden_states"] + current_cache.append(layer_outputs.get("cache")) + + if "router_logits" in layer_outputs: + all_router_logits.append(layer_outputs["router_logits"]) + + return { + "hidden_states": hidden_states, + "cache": current_cache, + "all_router_logits": all_router_logits, + } + + def get_config(self): + config = super().get_config() + config.update( + { + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_experts": self.num_experts, + "num_experts_per_tok": self.num_experts_per_tok, + "head_dim": self.head_dim, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + "sliding_window_size": self.sliding_window_size, + "max_sequence_length": self.max_sequence_length, + } + ) + return config diff --git a/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_layernorm.py 
b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_layernorm.py new file mode 100644 index 0000000000..b6197e4a56 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_layernorm.py @@ -0,0 +1,68 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export + + +@keras_hub_export("keras_hub.models.Qwen3OmniMoeLayerNorm") +class Qwen3OmniMoeLayerNorm(keras.layers.Layer): + """RMS Normalization layer for Qwen3-Omni MoE model. + + RMSNorm (Root Mean Square Normalization) is a normalization technique + that normalizes inputs by the root mean square of the inputs, without + centering them around zero. This is commonly used in modern transformer + architectures like Qwen models. + + Args: + epsilon: float, default 1e-6. A small value added to the denominator + for numerical stability. + dtype: str or `keras.mixed_precision.DTypePolicy`. The dtype to use for + the layer's computations and weights. + + Example: + ```python + # Create an RMSNorm layer + layer = Qwen3OmniMoeLayerNorm(epsilon=1e-6) + + # Apply to input tensor + inputs = keras.random.normal((2, 10, 128)) + outputs = layer(inputs) # Shape: (2, 10, 128) + ``` + """ + + def __init__( + self, + epsilon=1e-6, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.epsilon = epsilon + + def build(self, input_shape): + self.gamma = self.add_weight( + name="gamma", + shape=(input_shape[-1],), + initializer="ones", + dtype=self.dtype, + ) + super().build(input_shape) + + def call(self, inputs): + # Compute mean of squares + variance = ops.mean(ops.square(inputs), axis=-1, keepdims=True) + + # Normalize + normalized = inputs / ops.sqrt(variance + self.epsilon) + + # Scale + return self.gamma * normalized + + def get_config(self): + config = super().get_config() + config.update( + { + "epsilon": self.epsilon, + } + ) + return config \ No newline at end of file diff --git a/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_presets.py b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_presets.py new file mode 100644 index 0000000000..8661509243 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_presets.py @@ -0,0 +1,236 @@ +"""Qwen3-Omni MoE model presets.""" + +from keras_hub.src.utils.preset_utils import register_presets + +# Qwen3-Omni MoE model presets +backbone_presets = { + "qwen3_omni_moe_7b": { + "metadata": { + "description": "Qwen3-Omni MoE 7B model with multimodal capabilities", + "parameters": "7B", + "model_size": "7B", + }, + "kaggle_handle": "kaggle://keras/qwen3-omni-moe-7b/keras/qwen3_omni_moe_7b", + "config": { + "vocabulary_size": 151936, + "num_layers": 32, + "num_query_heads": 32, + "hidden_dim": 4096, + "intermediate_dim": 11008, + "num_key_value_heads": 4, + "rope_max_wavelength": 10000, + "rope_scaling_factor": 1.0, + "layer_norm_epsilon": 1e-6, + "dropout": 0.0, + "num_experts": 8, + "num_experts_per_tok": 2, + "audio_config": { + "vocab_size": 1024, + "hidden_size": 4096, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "num_key_value_heads": 4, + "intermediate_size": 11008, + "hidden_act": "silu", + "max_position_embeddings": 32768, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "rope_theta": 10000.0, + "rope_scaling": None, + "attention_bias": False, + "attention_dropout": 0.0, + "num_experts": 8, + "num_experts_per_tok": 2, + "audio_vocab_size": 1024, + "audio_hidden_size": 4096, + "audio_num_hidden_layers": 32, + "audio_num_attention_heads": 32, + "audio_num_key_value_heads": 4, + 
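+                # Note: the audio_*-prefixed keys in this block duplicate the
+                # unprefixed fields already defined in this audio_config.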
"audio_intermediate_size": 11008, + "audio_hidden_act": "silu", + "audio_max_position_embeddings": 32768, + "audio_initializer_range": 0.02, + "audio_rms_norm_eps": 1e-6, + "audio_use_cache": True, + "audio_rope_theta": 10000.0, + "audio_rope_scaling": None, + "audio_attention_bias": False, + "audio_attention_dropout": 0.0, + "audio_num_experts": 8, + "audio_num_experts_per_tok": 2, + }, + "vision_config": { + "hidden_size": 4096, + "intermediate_size": 11008, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "num_key_value_heads": 4, + "image_size": 448, + "patch_size": 14, + "num_channels": 3, + "hidden_act": "silu", + "layer_norm_eps": 1e-6, + "attention_dropout": 0.0, + "initializer_range": 0.02, + "initializer_factor": 1.0, + "use_cache": True, + "rope_theta": 10000.0, + "rope_scaling": None, + "attention_bias": False, + "spatial_merge_size": 2, + "num_experts": 8, + "num_experts_per_tok": 2, + }, + "text_config": { + "vocab_size": 151936, + "hidden_size": 4096, + "intermediate_size": 11008, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "num_key_value_heads": 4, + "hidden_act": "silu", + "max_position_embeddings": 32768, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "rope_theta": 10000.0, + "rope_scaling": None, + "attention_bias": False, + "attention_dropout": 0.0, + "num_experts": 8, + "num_experts_per_tok": 2, + }, + "thinker_config": { + "audio_config": { + "vocab_size": 1024, + "hidden_size": 4096, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "num_key_value_heads": 4, + "intermediate_size": 11008, + "hidden_act": "silu", + "max_position_embeddings": 32768, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "rope_theta": 10000.0, + "rope_scaling": None, + "attention_bias": False, + "attention_dropout": 0.0, + "num_experts": 8, + "num_experts_per_tok": 2, + }, + "vision_config": { + "hidden_size": 4096, + "intermediate_size": 11008, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "num_key_value_heads": 4, + "image_size": 448, + "patch_size": 14, + "num_channels": 3, + "hidden_act": "silu", + "layer_norm_eps": 1e-6, + "attention_dropout": 0.0, + "initializer_range": 0.02, + "initializer_factor": 1.0, + "use_cache": True, + "rope_theta": 10000.0, + "rope_scaling": None, + "attention_bias": False, + "spatial_merge_size": 2, + "num_experts": 8, + "num_experts_per_tok": 2, + }, + "text_config": { + "vocab_size": 151936, + "hidden_size": 4096, + "intermediate_size": 11008, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "num_key_value_heads": 4, + "hidden_act": "silu", + "max_position_embeddings": 32768, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "rope_theta": 10000.0, + "rope_scaling": None, + "attention_bias": False, + "attention_dropout": 0.0, + "num_experts": 8, + "num_experts_per_tok": 2, + }, + }, + "talker_config": { + "text_config": { + "vocab_size": 151936, + "hidden_size": 4096, + "intermediate_size": 11008, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "num_key_value_heads": 4, + "hidden_act": "silu", + "max_position_embeddings": 32768, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "rope_theta": 10000.0, + "rope_scaling": None, + "attention_bias": False, + "attention_dropout": 0.0, + "num_experts": 8, + "num_experts_per_tok": 2, + }, + "code_predictor_config": { + "vocab_size": 1024, + "hidden_size": 4096, + "intermediate_size": 11008, + "num_hidden_layers": 32, + "num_attention_heads": 32, + 
"num_key_value_heads": 4, + "hidden_act": "silu", + "max_position_embeddings": 32768, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "rope_theta": 10000.0, + "rope_scaling": None, + "attention_bias": False, + "attention_dropout": 0.0, + "num_experts": 8, + "num_experts_per_tok": 2, + }, + }, + "code2wav_config": { + "vocab_size": 1024, + "hidden_size": 4096, + "intermediate_size": 11008, + "num_hidden_layers": 32, + "num_attention_heads": 32, + "num_key_value_heads": 4, + "hidden_act": "silu", + "max_position_embeddings": 32768, + "initializer_range": 0.02, + "rms_norm_eps": 1e-6, + "use_cache": True, + "rope_theta": 10000.0, + "rope_scaling": None, + "attention_bias": False, + "attention_dropout": 0.0, + "num_experts": 8, + "num_experts_per_tok": 2, + }, + "enable_audio_output": True, + "im_start_token_id": 151644, + "im_end_token_id": 151645, + "tts_pad_token_id": 151671, + "tts_bos_token_id": 151672, + "tts_eos_token_id": 151673, + "system_token_id": 8948, + "user_token_id": 872, + "assistant_token_id": 77091, + }, + }, +} diff --git a/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_tokenizer.py b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_tokenizer.py new file mode 100644 index 0000000000..c0766d9a02 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_tokenizer.py @@ -0,0 +1,58 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_backbone import Qwen3OmniMoeBackbone +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +@keras_hub_export( + "keras_hub.tokenizers.Qwen3OmniMoeTokenizer", +) +class Qwen3OmniMoeTokenizer(BytePairTokenizer): + """Tokenizer for Qwen3-Omni MoE model. + + This tokenizer implements byte-pair encoding (BPE) for Qwen3-Omni MoE models, + handling special tokens like BOS (beginning of sequence) and EOS (end of + sequence). It supports multimodal capabilities including text, audio, and vision. + + Args: + vocabulary: Dictionary mapping tokens to token IDs, or path to + vocabulary file. + merges: List of BPE merges, or path to merges file. + bos_token: Beginning of sequence token. Defaults to None. + eos_token: End of sequence token. Defaults to "<|im_end|>". + misc_special_tokens: Set of additional special tokens. Defaults to + empty set. 
+ + Example: + ```python + # Create tokenizer + tokenizer = Qwen3OmniMoeTokenizer.from_preset("qwen3_omni_moe_7b") + + # Tokenize text + tokens = tokenizer("Hello, world!") + # Returns: {'token_ids': array([[151644, 8948, 77091, 151645, 0, 0]]), 'padding_mask': array([[1, 1, 1, 1, 0, 0]])} + ``` + """ + + backbone_cls = Qwen3OmniMoeBackbone + + def __init__( + self, + vocabulary=None, + merges=None, + **kwargs, + ): + # Add EOS token + eos_token = "<|im_end|>" + self._add_special_token(eos_token, "end_token") + + pad_token = "<|endoftext|>" + self._add_special_token(pad_token, "pad_token") + + self.start_token_id = None + self.start_token = None + + super().__init__( + vocabulary=vocabulary, + merges=merges, + **kwargs, + ) diff --git a/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_tokenizer_test.py b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_tokenizer_test.py new file mode 100644 index 0000000000..76341248f5 --- /dev/null +++ b/keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_tokenizer_test.py @@ -0,0 +1,115 @@ +import pytest +from keras import ops + +from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_tokenizer import Qwen3OmniMoeTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class Qwen3OmniMoeTokenizerTest(TestCase): + def setUp(self): + # Create a dummy vocabulary for testing + self.vocabulary = { + "<|endoftext|>": 0, + "<|im_end|>": 1, + "hello": 2, + "world": 3, + "how": 4, + "are": 5, + "you": 6, + "the": 7, + "quick": 8, + "brown": 9, + "fox": 10, + } + self.merges = ["h e", "l l", "o ", "w o", "r l", "d "] + + self.tokenizer = Qwen3OmniMoeTokenizer( + vocabulary=self.vocabulary, + merges=self.merges, + ) + + def test_tokenizer_basics(self): + """Test basic tokenizer functionality.""" + # Test tokenization + text = "hello world" + tokens = self.tokenizer(text) + + # Should return token_ids and padding_mask + self.assertIn("token_ids", tokens) + self.assertIn("padding_mask", tokens) + + # Check shapes + self.assertEqual(len(tokens["token_ids"].shape), 2) # (batch_size, seq_len) + self.assertEqual(len(tokens["padding_mask"].shape), 2) # (batch_size, seq_len) + + def test_tokenizer_special_tokens(self): + """Test that special tokens are properly added.""" + # Check that special tokens are in vocabulary + self.assertIn("<|endoftext|>", self.tokenizer.get_vocabulary()) + self.assertIn("<|im_end|>", self.tokenizer.get_vocabulary()) + + def test_tokenizer_batch_processing(self): + """Test batch processing of multiple texts.""" + texts = ["hello world", "how are you", "the quick brown fox"] + tokens = self.tokenizer(texts) + + # Should handle multiple texts + batch_size = tokens["token_ids"].shape[0] + self.assertEqual(batch_size, 3) + + def test_tokenizer_detokenization(self): + """Test detokenization round-trip.""" + original_text = "hello world" + tokens = self.tokenizer(original_text) + detokenized = self.tokenizer.detokenize(tokens["token_ids"]) + + # Should be able to detokenize (though may not be exact due to subword tokenization) + self.assertIsInstance(detokenized, str) + + def test_tokenizer_from_preset(self): + """Test loading tokenizer from preset.""" + # This test will be skipped if no presets are available + try: + tokenizer = Qwen3OmniMoeTokenizer.from_preset("qwen3_omni_moe_7b") + # If successful, test basic functionality + tokens = tokenizer("test") + self.assertIn("token_ids", tokens) + except (ValueError, ImportError): + # Skip if preset not available or dependencies missing + pytest.skip("Preset not available or dependencies 
missing") + + def test_tokenizer_vocabulary_size(self): + """Test that vocabulary size is correct.""" + vocab_size = len(self.tokenizer.get_vocabulary()) + self.assertGreater(vocab_size, 0) + self.assertEqual(vocab_size, len(self.vocabulary)) + + def test_tokenizer_padding(self): + """Test tokenizer padding behavior.""" + # Test with different length inputs + short_text = "hello" + long_text = "hello world how are you today" + + short_tokens = self.tokenizer(short_text) + long_tokens = self.tokenizer(long_text) + + # Both should have valid outputs + self.assertIn("token_ids", short_tokens) + self.assertIn("token_ids", long_tokens) + + def test_tokenizer_empty_input(self): + """Test tokenizer with empty input.""" + tokens = self.tokenizer("") + + # Should handle empty input gracefully + self.assertIn("token_ids", tokens) + self.assertIn("padding_mask", tokens) + + def test_tokenizer_special_characters(self): + """Test tokenizer with special characters.""" + text = "Hello, world! How are you? I'm fine, thanks." + tokens = self.tokenizer(text) + + # Should handle special characters + self.assertIn("token_ids", tokens) + self.assertIn("padding_mask", tokens)