Commit d9a3ac1

Eitan Porath authored and hengtaoguo committed
Add Qwen3 Omni Vision Encoder
cast pyconfig types gemma3 vision is nope
1 parent c0abc4c commit d9a3ac1

File tree

14 files changed: +2022 −24 lines changed


src/MaxText/configs/base.yml

Lines changed: 7 additions & 0 deletions
@@ -897,6 +897,13 @@ vision_output_dim_for_vit: 4096
 pixel_shuffle_ratio_for_vit: 0.5
 projector_dropout_for_vit: 0.0

+# Qwen3-OmniMoe vision encoder
+spatial_merge_size_for_vit: 2
+out_hidden_size_for_vit: 512
+temporal_patch_size_for_vit: 2
+num_position_embeddings_for_vit: 1024
+deepstack_visual_indexes_for_vit: []
+
 # Subslice shape in the form of "x,y,z" when using pathways (single controller).
 # Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium.
 subslice_shape: ""
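These base.yml entries only set repository-wide defaults; the model config below overrides them. As a rough illustration of what the two new size knobs mean, here is a small hypothetical sketch (names and arithmetic are mine, not from the commit) of how the merge and patch sizes translate into the number of vision tokens handed to the language model:

# Hypothetical helper, not part of this commit: token count after the ViT
# merges spatial patches. spatial_merge_size groups patches into s x s blocks;
# temporal_patch_size groups video frames before patch embedding.
def vision_token_count(num_frames, grid_h, grid_w, spatial_merge_size=2, temporal_patch_size=2):
  frame_groups = max(num_frames // temporal_patch_size, 1)
  merged_h = grid_h // spatial_merge_size
  merged_w = grid_w // spatial_merge_size
  return frame_groups * merged_h * merged_w

# A single 768x768 image with patch_size 16 gives a 48x48 patch grid,
# which merges down to 24x24 = 576 tokens.
print(vision_token_count(num_frames=1, grid_h=48, grid_w=48))  # 576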

src/MaxText/configs/models/qwen3-omni-30b-a3b.yml

Lines changed: 19 additions & 1 deletion
@@ -34,7 +34,25 @@ base_moe_mlp_dim: 768
 norm_topk_prob: true

 # RoPE Settings
-rope_max_timescale: 10_000_000
+rope_max_timescale: 1_000_000
+max_position_embeddings: 65536

 # General Model Settings
 enable_dropout: False
+
+# Vision Encoder Configuration
+# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py
+image_size_for_vit: 768
+hidden_size_for_vit: 1152
+intermediate_size_for_vit: 4304
+num_attention_heads_for_vit: 16
+num_hidden_layers_for_vit: 27
+num_channels_for_vit: 3
+patch_size_for_vit: 16
+temporal_patch_size_for_vit: 2
+spatial_merge_size_for_vit: 2
+out_hidden_size_for_vit: 2048
+num_position_embeddings_for_vit: 2304
+deepstack_visual_indexes_for_vit: [8, 16, 24]
+
+use_multimodal: true
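A quick consistency check on the numbers above, as a hedged sketch (the relationships are assumptions drawn from the referenced Hugging Face config, not stated in this diff):

image_size, patch_size = 768, 16
hidden_size, spatial_merge_size = 1152, 2

patches_per_side = image_size // patch_size           # 48
num_patches = patches_per_side ** 2                   # 2304, matches num_position_embeddings_for_vit
merged_dim = hidden_size * spatial_merge_size ** 2    # 4608, assumed to be projected down to out_hidden_size_for_vit = 2048
print(patches_per_side, num_patches, merged_dim)      # 48 2304 4608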

src/MaxText/configs/types.py

Lines changed: 5 additions & 0 deletions
@@ -1194,6 +1194,11 @@ class VisionTower(BaseModel):
   num_hidden_layers_for_vit: int = Field(34, description="Number of hidden layers in the Vision Transformer.")
   rope_theta_for_vit: int = Field(10000, description="RoPE theta value for the Vision Transformer.")
   vision_output_dim_for_vit: int = Field(4096, description="Final output dimension of the vision-to-language projection.")
+  spatial_merge_size_for_vit: int = Field(2, description="Spatial merge factor for vision patches.")
+  out_hidden_size_for_vit: int = Field(512, description="Output dimension of ViT.")
+  temporal_patch_size_for_vit: int = Field(2, description="Temporal patch size for video inputs.")
+  num_position_embeddings_for_vit: int = Field(1024, description="Number of position embeddings for ViT.")
+  deepstack_visual_indexes_for_vit: list[int] = Field([], description="Layer indices to extract deep visual features.")


 class VisionProjector(BaseModel):
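These Pydantic fields are what give YAML and CLI overrides a concrete type (the "cast pyconfig types" note in the commit message). A minimal, self-contained sketch of that coercion, using a stand-in class rather than the real VisionTower:

from pydantic import BaseModel, Field

class VisionTowerSketch(BaseModel):
  # Stand-in for a subset of the fields added above.
  spatial_merge_size_for_vit: int = Field(2)
  deepstack_visual_indexes_for_vit: list[int] = Field([])

# String overrides (as they arrive from the command line) are coerced to the
# declared types during validation.
cfg = VisionTowerSketch(spatial_merge_size_for_vit="4", deepstack_visual_indexes_for_vit=["8", "16", "24"])
print(cfg.spatial_merge_size_for_vit)        # 4 (int)
print(cfg.deepstack_visual_indexes_for_vit)  # [8, 16, 24] (list[int])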

src/MaxText/layers/attention_mla.py

Lines changed: 1 addition & 0 deletions
@@ -671,6 +671,7 @@ def __call__(
       slot: Optional[int] = None,
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Optional[Any] = None,
+      rope_kwargs: dict | None = None,
   ) -> Array:
     """Forward pass for MLA, reusing `AttentionOp` for the actual attention.

src/MaxText/layers/attentions.py

Lines changed: 41 additions & 15 deletions
@@ -16,7 +16,7 @@

 import dataclasses
 import functools
-from typing import Any, Iterable, Optional, Tuple, Union
+from typing import Any, Iterable, Optional, Tuple, Union, cast

 from jax.ad_checkpoint import checkpoint_name
 from jax.sharding import Mesh, NamedSharding

@@ -63,6 +63,7 @@
 from MaxText.layers.embeddings import (
     LLaMARotaryEmbedding,
     LlamaVisionRotaryEmbedding,
+    Qwen3OmniMoeVisionRotaryEmbedding,
     RotaryEmbedding,
     YarnRotaryEmbedding,
     Qwen3NextRotaryEmbedding,

@@ -720,15 +721,29 @@ def init_rotary_embedding(self):
     rope_type = self.config.rope_type.lower()
     rope_use_scale = self.config.rope_use_scale
     if self.is_vision:
-      rotary_embedding = LlamaVisionRotaryEmbedding(
-          image_size=self.config.image_size_for_vit,
-          patch_size=self.config.patch_size_for_vit,
-          hidden_size=self.config.hidden_size_for_vit,
-          num_attention_heads=self.config.num_attention_heads_for_vit,
-          rope_theta=self.config.rope_theta_for_vit,
-          fprop_dtype=self.dtype,
-          rngs=self.rngs,
-      )
+      if self.config.model_name.startswith("qwen3-omni"):
+        rotary_embedding = Qwen3OmniMoeVisionRotaryEmbedding(
+            hidden_size=self.config.hidden_size_for_vit,
+            num_attention_heads=self.config.num_attention_heads_for_vit,
+            spatial_merge_size=self.config.spatial_merge_size_for_vit,
+            rope_theta=self.config.rope_theta_for_vit,
+            fprop_dtype=self.dtype,
+            rngs=self.rngs,
+        )
+      elif self.config.model_name.startswith("llama4"):
+        rotary_embedding = LlamaVisionRotaryEmbedding(
+            image_size=self.config.image_size_for_vit,
+            patch_size=self.config.patch_size_for_vit,
+            hidden_size=self.config.hidden_size_for_vit,
+            num_attention_heads=self.config.num_attention_heads_for_vit,
+            rope_theta=self.config.rope_theta_for_vit,
+            cast_as_fprop_dtype=True,
+            fprop_dtype=self.dtype,
+            rngs=self.rngs,
+        )
+      else:
+        raise ValueError(f"Unsupported model type for vision rotary embedding: {self.config.model_name}")
+
     elif self.config.model_name.startswith("llama3.1") or rope_type.startswith("llama3.1"):
       rotary_embedding = LLaMARotaryEmbedding(
           min_timescale=self.config.rope_min_timescale,

@@ -784,18 +799,28 @@ def init_rotary_embedding(self):
       )
     return rotary_embedding

-  def apply_rotary_embedding(self, inputs: Array, inputs_positions: Optional[Array | None] = None):
+  def apply_rotary_embedding(
+      self, inputs: Array, inputs_positions: Optional[Array | None] = None, rope_kwargs: dict | None = None
+  ):
     """Applies rotary embeddings, handling different model types.

     Args:
       inputs: The input tensor to apply rotary embeddings to.
       inputs_positions: The positions of the inputs.
-      name: A name for the embedding layer.
+      rope_kwargs: A dictionary of keyword arguments for the rotary embedding.

     Returns:
       The input tensor with rotary embeddings applied.
     """
-    return self.rotary_embedding(inputs, inputs_positions)
+    if isinstance(self.rotary_embedding, Qwen3OmniMoeVisionRotaryEmbedding):
+      # For Qwen3OmniMoe vision, pass static dimensions from kwargs.
+      num_frames = rope_kwargs.get("num_frames")
+      height = rope_kwargs.get("height")
+      width = rope_kwargs.get("width")
+      # Type cast required: Omni rotary embedding uses different __call__ parameters than other embeddings.
+      return cast(Qwen3OmniMoeVisionRotaryEmbedding, self.rotary_embedding)(inputs, num_frames, height, width)
+    else:
+      return self.rotary_embedding(inputs, inputs_positions)

   def init_kv_caches(self, inputs_kv_shape: Tuple):
     """Initializes KVCache.

@@ -878,6 +903,7 @@ def __call__(
       slot: Optional[int] = None,
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Any = None,
+      rope_kwargs: dict | None = None,
   ):
     """Applies Attention on the input data.

@@ -952,8 +978,8 @@ def __call__(
     use_qk_norm = self.use_qk_norm and use_rope

     if use_rope:
-      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions)
-      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions)
+      query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)
+      key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs)

     if use_qk_norm and is_llama4_decoder_block:
       l2_norm = L2Norm(eps=self.config.normalization_layer_epsilon)
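To make the control flow above easier to follow, here is a self-contained toy sketch of the same dispatch pattern: the rotary embedding is chosen from the model-name prefix, and rope_kwargs carries the static num_frames/height/width only when the Omni vision embedding is in use. The classes here are illustrative stand-ins, not the MaxText implementations:

from typing import Optional

class OmniVisionRopeSketch:
  def __call__(self, x, num_frames, height, width):
    # The real embedding rotates by temporal/2D spatial position; elided here.
    return x

class TextRopeSketch:
  def __call__(self, x, positions):
    return x

def make_rotary(model_name: str, is_vision: bool):
  if is_vision and model_name.startswith("qwen3-omni"):
    return OmniVisionRopeSketch()
  return TextRopeSketch()

def apply_rotary(rope, x, positions=None, rope_kwargs: Optional[dict] = None):
  if isinstance(rope, OmniVisionRopeSketch):
    kw = rope_kwargs or {}
    return rope(x, kw.get("num_frames"), kw.get("height"), kw.get("width"))
  return rope(x, positions)

rope = make_rotary("qwen3-omni-30b-a3b", is_vision=True)
_ = apply_rotary(rope, x=[[0.0]], rope_kwargs={"num_frames": 1, "height": 48, "width": 48})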

src/MaxText/layers/decoders.py

Lines changed: 8 additions & 1 deletion
@@ -558,7 +558,14 @@ def _apply_embedding(

     # Merge the image embeddings with the text embeddings for multimodal models
     if image_embeddings is not None and cfg.use_multimodal:
-      if cfg.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e"]:
+      if cfg.model_name in [
+          "gemma3-4b",
+          "gemma3-12b",
+          "gemma3-27b",
+          "llama4-17b-16e",
+          "llama4-17b-128e",
+          "qwen3-omni-30b-a3b",
+      ]:
         y = multimodal_utils.merge_mm_embeddings(
             text_embeddings=y,
             vision_embeddings=image_embeddings,
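qwen3-omni-30b-a3b joins the allow-list above, so its image embeddings are merged into the token stream the same way as for the Gemma3 and Llama4 models. As an illustration of what such a merge typically does, here is a hedged JAX sketch that scatters vision embeddings into the image-placeholder positions; the real multimodal_utils.merge_mm_embeddings may differ:

import jax.numpy as jnp

def merge_mm_embeddings_sketch(text_embeddings, vision_embeddings, image_token_mask):
  """text_embeddings: [batch, seq, dim]; vision_embeddings: [num_image_tokens, dim];
  image_token_mask: [batch, seq] bool, True where an image-placeholder token sits."""
  b, s, d = text_embeddings.shape
  flat_text = text_embeddings.reshape(-1, d)
  flat_mask = image_token_mask.reshape(-1)
  # Positions of the placeholder tokens, one per vision embedding.
  positions = jnp.nonzero(flat_mask, size=vision_embeddings.shape[0])[0]
  merged = flat_text.at[positions].set(vision_embeddings)
  return merged.reshape(b, s, d)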
