16 | 16 |
17 | 17 | import dataclasses |
18 | 18 | import functools |
19 | | -from typing import Any, Iterable, Optional, Tuple, Union |
| 19 | +from typing import Any, Iterable, Optional, Tuple, Union, cast |
20 | 20 |
21 | 21 | from jax.ad_checkpoint import checkpoint_name |
22 | 22 | from jax.sharding import Mesh, NamedSharding |
63 | 63 | from MaxText.layers.embeddings import ( |
64 | 64 | LLaMARotaryEmbedding, |
65 | 65 | LlamaVisionRotaryEmbedding, |
| 66 | + Qwen3OmniMoeVisionRotaryEmbedding, |
66 | 67 | RotaryEmbedding, |
67 | 68 | YarnRotaryEmbedding, |
68 | 69 | Qwen3NextRotaryEmbedding, |
@@ -720,15 +721,29 @@ def init_rotary_embedding(self): |
720 | 721 | rope_type = self.config.rope_type.lower() |
721 | 722 | rope_use_scale = self.config.rope_use_scale |
722 | 723 | if self.is_vision: |
723 | | - rotary_embedding = LlamaVisionRotaryEmbedding( |
724 | | - image_size=self.config.image_size_for_vit, |
725 | | - patch_size=self.config.patch_size_for_vit, |
726 | | - hidden_size=self.config.hidden_size_for_vit, |
727 | | - num_attention_heads=self.config.num_attention_heads_for_vit, |
728 | | - rope_theta=self.config.rope_theta_for_vit, |
729 | | - fprop_dtype=self.dtype, |
730 | | - rngs=self.rngs, |
731 | | - ) |
| 724 | + if self.config.model_name.startswith("qwen3-omni"): |
| 725 | + rotary_embedding = Qwen3OmniMoeVisionRotaryEmbedding( |
| 726 | + hidden_size=self.config.hidden_size_for_vit, |
| 727 | + num_attention_heads=self.config.num_attention_heads_for_vit, |
| 728 | + spatial_merge_size=self.config.spatial_merge_size_for_vit, |
| 729 | + rope_theta=self.config.rope_theta_for_vit, |
| 730 | + fprop_dtype=self.dtype, |
| 731 | + rngs=self.rngs, |
| 732 | + ) |
| 733 | + elif self.config.model_name.startswith("llama4"): |
| 734 | + rotary_embedding = LlamaVisionRotaryEmbedding( |
| 735 | + image_size=self.config.image_size_for_vit, |
| 736 | + patch_size=self.config.patch_size_for_vit, |
| 737 | + hidden_size=self.config.hidden_size_for_vit, |
| 738 | + num_attention_heads=self.config.num_attention_heads_for_vit, |
| 739 | + rope_theta=self.config.rope_theta_for_vit, |
| 740 | + cast_as_fprop_dtype=True, |
| 741 | + fprop_dtype=self.dtype, |
| 742 | + rngs=self.rngs, |
| 743 | + ) |
| 744 | + else: |
| 745 | + raise ValueError(f"Unsupported model type for vision rotary embedding: {self.config.model_name}") |
| 746 | + |
732 | 747 | elif self.config.model_name.startswith("llama3.1") or rope_type.startswith("llama3.1"): |
733 | 748 | rotary_embedding = LLaMARotaryEmbedding( |
734 | 749 | min_timescale=self.config.rope_min_timescale, |
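Note (not part of the diff): unlike the Llama4 vision RoPE above, which precomputes its rotation table from a fixed `image_size_for_vit` / `patch_size_for_vit`, the Qwen3-Omni variant only takes `spatial_merge_size_for_vit` up front and rebuilds the position grid at call time from `(num_frames, height, width)`. The sketch below is illustrative only, showing one common way a merge-window-aware 2D grid can be derived from those dimensions; it is not the MaxText `Qwen3OmniMoeVisionRotaryEmbedding` implementation.

```python
# Illustrative sketch only, not the MaxText Qwen3OmniMoeVisionRotaryEmbedding internals:
# build per-patch (row, col) position ids for a height x width grid, ordered so that each
# spatial_merge_size x spatial_merge_size window stays contiguous.
import jax.numpy as jnp

def merged_grid_positions(height: int, width: int, merge: int) -> jnp.ndarray:
  """Returns an array of shape (height * width, 2) holding (row, col) ids per patch."""
  shape = (height // merge, merge, width // merge, merge)
  rows = jnp.broadcast_to(jnp.arange(height).reshape(height // merge, merge, 1, 1), shape)
  cols = jnp.broadcast_to(jnp.arange(width).reshape(1, 1, width // merge, merge), shape)
  # Reorder so the merge x merge patches of each window are adjacent, matching
  # window-wise patch merging; for video inputs this grid would be repeated per frame.
  rows = rows.transpose(0, 2, 1, 3).reshape(-1)
  cols = cols.transpose(0, 2, 1, 3).reshape(-1)
  return jnp.stack([rows, cols], axis=-1)
```

A real implementation would feed such (row, col) ids into the usual sin/cos rotation built from `rope_theta_for_vit`; the sketch only motivates why the merge size replaces the fixed image geometry in the constructor above.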
@@ -784,18 +799,28 @@ def init_rotary_embedding(self): |
784 | 799 | ) |
785 | 800 | return rotary_embedding |
786 | 801 |
787 | | - def apply_rotary_embedding(self, inputs: Array, inputs_positions: Optional[Array | None] = None): |
| 802 | + def apply_rotary_embedding( |
| 803 | + self, inputs: Array, inputs_positions: Optional[Array | None] = None, rope_kwargs: dict | None = None |
| 804 | + ): |
788 | 805 | """Applies rotary embeddings, handling different model types. |
789 | 806 |
790 | 807 | Args: |
791 | 808 | inputs: The input tensor to apply rotary embeddings to. |
792 | 809 | inputs_positions: The positions of the inputs. |
793 | | - name: A name for the embedding layer. |
| 810 | + rope_kwargs: A dictionary of keyword arguments for the rotary embedding. |
794 | 811 |
795 | 812 | Returns: |
796 | 813 | The input tensor with rotary embeddings applied. |
797 | 814 | """ |
798 | | - return self.rotary_embedding(inputs, inputs_positions) |
| 815 | + if isinstance(self.rotary_embedding, Qwen3OmniMoeVisionRotaryEmbedding): |
| 816 | + # For Qwen3OmniMoe vision, pass static dimensions from kwargs. |
| 817 | + num_frames = rope_kwargs.get("num_frames") |
| 818 | + height = rope_kwargs.get("height") |
| 819 | + width = rope_kwargs.get("width") |
| 820 | + # Type cast required: Omni rotary embedding uses different __call__ parameters than other embeddings. |
| 821 | + return cast(Qwen3OmniMoeVisionRotaryEmbedding, self.rotary_embedding)(inputs, num_frames, height, width) |
| 822 | + else: |
| 823 | + return self.rotary_embedding(inputs, inputs_positions) |
799 | 824 |
800 | 825 | def init_kv_caches(self, inputs_kv_shape: Tuple): |
801 | 826 | """Initializes KVCache. |
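Usage sketch (assumed, not from the PR) of the two calling conventions of `apply_rotary_embedding` after this change; `attention_op`, the shapes, and the grid values are placeholders.

```python
import jax.numpy as jnp

# Default (text) path: rotation is driven by per-token positions.
q = jnp.zeros((1, 16, 8, 128))  # [batch, seq, heads, head_dim], illustrative shape
positions = jnp.arange(16)[None, :]
q_rot = attention_op.apply_rotary_embedding(q, inputs_positions=positions)

# Qwen3-Omni vision path: the 16 tokens are treated as a 1 x 4 x 4 patch grid and the
# rotary grid is rebuilt from those static dimensions; inputs_positions is unused here.
q_rot = attention_op.apply_rotary_embedding(
    q, rope_kwargs={"num_frames": 1, "height": 4, "width": 4}
)
```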
@@ -878,6 +903,7 @@ def __call__( |
878 | 903 | slot: Optional[int] = None, |
879 | 904 | page_state: Optional[page_manager.PageState] = None, |
880 | 905 | bidirectional_mask: Any = None, |
| 906 | + rope_kwargs: dict | None = None, |
881 | 907 | ): |
882 | 908 | """Applies Attention on the input data. |
883 | 909 |
@@ -952,8 +978,8 @@ def __call__( |
952 | 978 | use_qk_norm = self.use_qk_norm and use_rope |
953 | 979 |
954 | 980 | if use_rope: |
955 | | - query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions) |
956 | | - key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions) |
| 981 | + query = self.apply_rotary_embedding(query, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs) |
| 982 | + key = self.apply_rotary_embedding(key, inputs_positions=inputs_positions, rope_kwargs=rope_kwargs) |
957 | 983 |
958 | 984 | if use_qk_norm and is_llama4_decoder_block: |
959 | 985 | l2_norm = L2Norm(eps=self.config.normalization_layer_epsilon) |
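End-to-end sketch of threading the new argument through `__call__` (hedged: only `rope_kwargs` and the other parameters visible in this diff are confirmed; the positional arguments, attribute names, and grid variables below are illustrative assumptions about the surrounding vision block).

```python
# Hypothetical call site inside a Qwen3-Omni vision block; everything except the
# rope_kwargs parameter itself is an assumption made for illustration.
rope_kwargs = {"num_frames": num_frames, "height": grid_h, "width": grid_w}
attn_out = self.attention_layer(
    hidden_states,          # queries
    hidden_states,          # keys/values (self-attention)
    inputs_positions=None,  # unused on the Qwen3-Omni vision path
    rope_kwargs=rope_kwargs,
)
```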