-
Notifications
You must be signed in to change notification settings - Fork 301
model added qwen3_omni #2426
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
lukiod
wants to merge
2
commits into
keras-team:master
Choose a base branch
from
lukiod:feature/add-qwen3-omni-model
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
model added qwen3_omni #2426
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_backbone import Qwen3OmniMoeBackbone | ||
from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_presets import backbone_presets | ||
from keras_hub.src.utils.preset_utils import register_presets | ||
|
||
register_presets(backbone_presets, Qwen3OmniMoeBackbone) |
185 changes: 185 additions & 0 deletions
185
keras_hub/src/models/qwen3_omni_moe/qwen3_omni_moe_attention.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
import keras | ||
from keras import ops | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding | ||
from keras_hub.src.models.qwen3_omni_moe.qwen3_omni_moe_layernorm import Qwen3OmniMoeLayerNorm | ||
|
||
|
||
@keras_hub_export("keras_hub.models.Qwen3OmniMoeAttention") | ||
class Qwen3OmniMoeAttention(keras.layers.Layer): | ||
"""Multi-head attention for Qwen3-Omni MoE model.""" | ||
|
||
def __init__( | ||
self, | ||
num_query_heads, | ||
num_key_value_heads, | ||
hidden_dim, | ||
head_dim, | ||
layer_norm_epsilon=1e-6, | ||
dropout=0.0, | ||
sliding_window_size=4096, | ||
max_sequence_length=32768, | ||
dtype=None, | ||
**kwargs, | ||
): | ||
super().__init__(dtype=dtype, **kwargs) | ||
self.num_query_heads = num_query_heads | ||
self.num_key_value_heads = num_key_value_heads | ||
self.hidden_dim = hidden_dim | ||
self.head_dim = head_dim if head_dim is not None else hidden_dim // num_query_heads | ||
self.layer_norm_epsilon = layer_norm_epsilon | ||
self.dropout = dropout | ||
self.sliding_window_size = sliding_window_size | ||
self.max_sequence_length = max_sequence_length | ||
|
||
# Query projection | ||
self.query_projection = keras.layers.Dense( | ||
num_query_heads * self.head_dim, | ||
use_bias=False, | ||
dtype=dtype, | ||
name="query_projection", | ||
) | ||
|
||
# Key projection | ||
self.key_projection = keras.layers.Dense( | ||
num_key_value_heads * self.head_dim, | ||
use_bias=False, | ||
dtype=dtype, | ||
name="key_projection", | ||
) | ||
|
||
# Value projection | ||
self.value_projection = keras.layers.Dense( | ||
num_key_value_heads * self.head_dim, | ||
use_bias=False, | ||
dtype=dtype, | ||
name="value_projection", | ||
) | ||
|
||
# Output projection | ||
self.output_projection = keras.layers.Dense( | ||
hidden_dim, | ||
use_bias=False, | ||
dtype=dtype, | ||
name="output_projection", | ||
) | ||
|
||
# Rotary embedding | ||
self.rotary_embedding = RotaryEmbedding( | ||
max_wavelength=10000, | ||
scaling_factor=1.0, | ||
dtype=dtype, | ||
name="rotary_embedding", | ||
) | ||
|
||
def call( | ||
self, | ||
hidden_states, | ||
attention_mask=None, | ||
position_ids=None, | ||
cache=None, | ||
cache_update_index=None, | ||
training=None, | ||
): | ||
batch_size, seq_len, hidden_dim = ops.shape(hidden_states) | ||
|
||
# Project to query, key, value | ||
query = self.query_projection(hidden_states) | ||
key = self.key_projection(hidden_states) | ||
value = self.value_projection(hidden_states) | ||
|
||
# Reshape for multi-head attention | ||
query = ops.reshape( | ||
query, (batch_size, seq_len, self.num_query_heads, self.head_dim) | ||
) | ||
key = ops.reshape( | ||
key, (batch_size, seq_len, self.num_key_value_heads, self.head_dim) | ||
) | ||
value = ops.reshape( | ||
value, (batch_size, seq_len, self.num_key_value_heads, self.head_dim) | ||
) | ||
|
||
# Apply rotary embedding | ||
if position_ids is not None: | ||
query = self.rotary_embedding(query, position_ids) | ||
key = self.rotary_embedding(key, position_ids) | ||
|
||
# Handle cache | ||
if cache is not None: | ||
if cache_update_index is not None: | ||
# Update cache | ||
key = ops.concatenate([cache["key"], key], axis=1) | ||
value = ops.concatenate([cache["value"], value], axis=1) | ||
else: | ||
# Use cache | ||
key = cache["key"] | ||
value = cache["value"] | ||
lukiod marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
# Update cache | ||
new_cache = { | ||
"key": key, | ||
"value": value, | ||
} | ||
|
||
# Transpose for attention | ||
query = ops.transpose(query, (0, 2, 1, 3)) # (batch_size, num_heads, seq_len, head_dim) | ||
key = ops.transpose(key, (0, 2, 1, 3)) | ||
value = ops.transpose(value, (0, 2, 1, 3)) | ||
|
||
# Handle grouped query attention (GQA) | ||
# Repeat key and value for grouped query attention | ||
if self.num_key_value_heads < self.num_query_heads: | ||
num_groups = self.num_query_heads // self.num_key_value_heads | ||
key = ops.repeat(key, num_groups, axis=1) | ||
value = ops.repeat(value, num_groups, axis=1) | ||
|
||
# Compute attention scores | ||
attention_scores = ops.matmul(query, ops.transpose(key, (0, 1, 3, 2))) | ||
attention_scores = attention_scores / ops.sqrt(self.head_dim) | ||
|
||
# Apply attention mask | ||
if attention_mask is not None: | ||
if len(attention_mask.shape) == 2: | ||
# Convert 2D mask to 4D for broadcasting | ||
attention_mask = ops.expand_dims(attention_mask, axis=1) | ||
attention_mask = ops.expand_dims(attention_mask, axis=1) | ||
attention_scores = ops.where( | ||
attention_mask, attention_scores, ops.full_like(attention_scores, -1e9) | ||
) | ||
|
||
# Apply softmax | ||
attention_weights = ops.softmax(attention_scores, axis=-1) | ||
|
||
# Apply attention to values | ||
attention_output = ops.matmul(attention_weights, value) | ||
|
||
# Transpose back | ||
attention_output = ops.transpose(attention_output, (0, 2, 1, 3)) | ||
|
||
# Reshape and project | ||
attention_output = ops.reshape( | ||
attention_output, (batch_size, seq_len, self.num_query_heads * self.head_dim) | ||
) | ||
attention_output = self.output_projection(attention_output) | ||
|
||
return { | ||
"hidden_states": attention_output, | ||
"cache": new_cache, | ||
} | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"num_query_heads": self.num_query_heads, | ||
"num_key_value_heads": self.num_key_value_heads, | ||
"hidden_dim": self.hidden_dim, | ||
"head_dim": self.head_dim, | ||
"layer_norm_epsilon": self.layer_norm_epsilon, | ||
"dropout": self.dropout, | ||
"sliding_window_size": self.sliding_window_size, | ||
"max_sequence_length": self.max_sequence_length, | ||
} | ||
) | ||
return config |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.