
Commit b043824

Merge pull request #2561 from AI-Hypercomputer:hengtaoguo-vl
PiperOrigin-RevId: 827750023
2 parents: 84f3ad6 + ca5282b

6 files changed (+141, -1 lines changed)

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Model config for Qwen3-Omni-30B-A3B
+
+# Core Architectural Parameters
+decoder_block: "qwen3_moe"
+base_emb_dim: 2048
+base_mlp_dim: 768
+base_num_query_heads: 32
+base_num_kv_heads: 4
+base_num_decoder_layers: 48
+head_dim: 128
+mlp_activations: ["silu", "linear"]
+vocab_size: 152064
+normalization_layer_epsilon: 1.0e-6
+use_qk_norm: True
+
+# MoE Specific Parameters
+num_experts: 128
+num_experts_per_tok: 8
+base_moe_mlp_dim: 768
+norm_topk_prob: true
+
+# RoPE Settings
+rope_max_timescale: 10_000_000
+
+# General Model Settings
+enable_dropout: False
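
Editor's note: the "30B-A3B" suffix follows the Qwen convention of roughly 30B total parameters with about 3B activated per token, and the values in this config are consistent with that. A rough back-of-the-envelope check (my own sketch, not part of the PR; it ignores router, norm, and bias parameters and assumes untied input/output embeddings):

# Hedged sanity check of the "30B-A3B" naming, using the config values above.
emb, layers = 2048, 48                 # base_emb_dim, base_num_decoder_layers
q_heads, kv_heads, head_dim = 32, 4, 128
moe_mlp, experts, top_k = 768, 128, 8  # base_moe_mlp_dim, num_experts, num_experts_per_tok
vocab = 152064                         # vocab_size

attn = layers * head_dim * emb * (2 * q_heads + 2 * kv_heads)  # Q, K, V, O projections
per_expert = 3 * emb * moe_mlp                                 # gate, up, down projections
embed = vocab * emb

total = attn + layers * experts * per_expert + 2 * embed
active = attn + layers * top_k * per_expert + embed
print(f"total ~= {total / 1e9:.1f}B, active ~= {active / 1e9:.1f}B")  # ~= 30.5B, ~= 3.0B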

src/MaxText/pyconfig.py

Lines changed: 1 addition & 0 deletions
@@ -454,6 +454,7 @@ def validate_model_name(s: str) -> bool:
     "qwen3-30b-a3b",
     "qwen3-480b-a35b",
     "qwen3-next-80b-a3b",
+    "qwen3-omni-30b-a3b",
     "gpt3-175b",
     "gpt3-22b",
     "gpt3-6b",

src/MaxText/utils/ckpt_conversion/to_maxtext.py

Lines changed: 13 additions & 1 deletion
@@ -182,6 +182,18 @@ def _build_single_axis_stacked_tensor(
   return np.stack(tensors_to_stack, axis=axis_to_stack)


+def _get_hf_model(model_id: str, token: str):
+  """Loads the HuggingFace model based on model_id."""
+  # Some models require special classes to import
+  if model_id in ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]:
+    from transformers import Qwen3OmniMoeForConditionalGeneration  # pylint: disable=import-outside-toplevel
+
+    hf_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(model_id, token=token)
+  else:
+    hf_model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
+  return hf_model
+
+
 def main(argv: Sequence[str]) -> None:
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"  # Suppress TensorFlow logging
@@ -217,7 +229,7 @@ def main(argv: Sequence[str]) -> None:
   # Load HuggingFace model, config, and state_dict
   max_logging.log(f"Loading HuggingFace model: {model_id}...")
   hf_config_obj = AutoConfig.from_pretrained(model_id, token=hf_token)
-  hf_model = AutoModelForCausalLM.from_pretrained(model_id, token=hf_token)
+  hf_model = _get_hf_model(model_id, token=hf_token)
   hf_state_dict_numpy = hf_model.state_dict()
   for k, v in hf_state_dict_numpy.items():
     hf_state_dict_numpy[k] = v.numpy()
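
Editor's note: the if/else in _get_hf_model is fine for a single special case; if more multi-modal models follow, a table-driven dispatch keeps the branch flat. A hedged sketch of that alternative (my suggestion, not what the PR ships; it assumes each special class is importable from the top-level transformers namespace):

from transformers import AutoModelForCausalLM

# Model IDs that AutoModelForCausalLM cannot load, mapped to the class name
# that can. The Qwen3-Omni entry mirrors the PR; any further entries would be
# hypothetical.
_SPECIAL_MODEL_CLASSES = {
    "Qwen/Qwen3-Omni-30B-A3B-Instruct": "Qwen3OmniMoeForConditionalGeneration",
}


def get_hf_model(model_id: str, token: str):
  """Loads an HF model, falling back to AutoModelForCausalLM when no special class is registered."""
  class_name = _SPECIAL_MODEL_CLASSES.get(model_id)
  if class_name is None:
    return AutoModelForCausalLM.from_pretrained(model_id, token=token)
  import transformers  # pylint: disable=import-outside-toplevel

  return getattr(transformers, class_name).from_pretrained(model_id, token=token)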

src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py

Lines changed: 12 additions & 0 deletions
@@ -469,6 +469,17 @@
     vocab_size=151936,
 )

+qwen3_omni_30b_a3b_config = transformers.Qwen3OmniMoeConfig(
+    # TODO(hengtaoguo): Pure-text Omni model, need to fill in visual/audio/code2wav parts
+    architectures=["Qwen3OmniMoeForConditionalGeneration"],
+    thinker_config={
+        "text_config": {
+            "num_hidden_layers": 48,
+            "num_experts": 128,
+        }
+    },
+)
+
 HF_MODEL_CONFIGS = {
     "gemma2-2b": gemma2_2b_config,
     "gemma2-9b": gemma2_9b_config,
@@ -489,4 +500,5 @@
     "qwen3-30b-a3b": qwen3_30b_a3b_thinking_2507_config,
     "qwen3-235b-a22b": qwen3_235b_a22b_thinking_2507_config,
     "qwen3-480b-a35b": qwen3_coder_480b_a35b_config,
+    "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
 }
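
Editor's note: unlike the flat text-only configs earlier in this file, the Omni config nests its text parameters under thinker_config.text_config, and the mapping functions in the next file index into exactly that path. A minimal sketch of the round-trip (assuming a transformers release that ships Qwen3OmniMoeConfig, per the PR):

import transformers

cfg = transformers.Qwen3OmniMoeConfig(
    architectures=["Qwen3OmniMoeForConditionalGeneration"],
    thinker_config={"text_config": {"num_hidden_layers": 48, "num_experts": 128}},
)
# to_dict() serializes nested sub-configs, yielding the dict shape that the
# QWEN3_OMNI_MOE_* functions below index into.
text_cfg = cfg.to_dict()["thinker_config"]["text_config"]
print(text_cfg["num_hidden_layers"], text_cfg.get("num_experts", 0))  # 48 128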

src/MaxText/utils/ckpt_conversion/utils/param_mapping.py

Lines changed: 74 additions & 0 deletions
@@ -814,6 +814,78 @@ def reshape_kernel(input_tensor, target_shape):
   return mapping


+def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):
+  """Returns mapping from MaxText to HuggingFace Qwen3-Omni weight paths.
+
+  This function combines mappings from different modalities (text, vision,
+  audio, etc.) into a unified parameter mapping for the multi-modal
+  Qwen3-Omni model.
+
+  Args:
+    config (dict): Model configuration dictionary containing modality-specific configs.
+    scan_layers (bool, optional): Whether the model uses scanned layers. Defaults to False.
+
+  Returns:
+    dict: Combined mapping from all modalities.
+  """
+  # Collect all modality mappings
+  mapping = {}
+
+  # Text mapping with "thinker." prefix, reusing QWEN3-MOE mapping function
+  num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
+  n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
+  text_mapping = QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(
+      config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text}, scan_layers=scan_layers
+  )
+
+  # Add "thinker." prefix to text mapping values
+  for key, value in text_mapping.items():
+    text_mapping[key] = [f"thinker.{v}" for v in value] if isinstance(value, list) else f"thinker.{value}"
+  mapping.update(text_mapping)
+
+  # TODO(hengtaoguo): Add vision, audio, and other modality mappings here similarly
+  # mapping.update(vision_mapping), mapping.update(audio_mapping), etc.
+
+  return mapping
+
+
+def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving_to_hf=False):
+  """Creates parameter transformation functions for Qwen3-Omni.
+
+  This function provides a dictionary of transformation functions (hooks) for
+  converting Qwen3-Omni model parameters between MaxText and Hugging Face
+  formats. It handles embedding padding and kernel reshaping.
+
+  Args:
+    config (dict): Model configuration dictionary containing modality-specific configs.
+    scan_layers (bool, optional): Whether the model uses scanned layers.
+      Defaults to False.
+    saving_to_hf (bool, optional): The direction of conversion. True for
+      MaxText to Hugging Face, False for the reverse. Defaults to False.
+
+  Returns:
+    dict: A dictionary mapping MaxText parameter names to their corresponding
+      transformation functions.
+  """
+  # Collect all modality hooks
+  mapping = {}
+
+  # Text hooks, reusing QWEN3-MOE hook function
+  num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
+  n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
+  text_hooks = QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
+      config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
+      scan_layers=scan_layers,
+      saving_to_hf=saving_to_hf,
+  )
+  mapping.update(text_hooks)
+
+  # TODO(hengtaoguo): Add vision, audio, and other modality hooks here similarly
+  # mapping.update(vision_hooks), mapping.update(audio_hooks), etc.
+
+  return mapping
+
+
 def LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):
   """
   Returns a dictionary mapping from MaxText parameter names to
@@ -1007,6 +1079,7 @@ def from_hf():
     "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
     "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING,
 }

 HOOK_FNS = {
@@ -1028,4 +1101,5 @@ def from_hf():
     "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
 }
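
Editor's note: the only transformation the Omni wrapper adds on top of the reused Qwen3-MoE text mapping is the "thinker." prefix, because the Omni checkpoint nests the text model under its thinker module. A tiny self-contained illustration of that prefixing step (the parameter names here are simplified stand-ins, not the exact MaxText keys):

text_mapping = {
    "params-token_embedder-embedding": "model.embed_tokens.weight",
    "params-decoder-layers_0-self_attention-query-kernel": [
        "model.layers.0.self_attn.q_proj.weight",
    ],
}
# Same transformation as in QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING:
# list-valued entries (e.g. stacked scan layers) are prefixed element-wise.
for key, value in text_mapping.items():
  text_mapping[key] = [f"thinker.{v}" for v in value] if isinstance(value, list) else f"thinker.{value}"

print(text_mapping["params-token_embedder-embedding"])
# -> thinker.model.embed_tokens.weight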

src/MaxText/utils/ckpt_conversion/utils/utils.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@
   "qwen3-30b-a3b": "Qwen/Qwen3-30B-A3B-Thinking-2507",
   "qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B-Thinking-2507",
   "qwen3-480b-a35b": "Qwen/Qwen3-Coder-480B-A35B-Instruct",
+  "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
 }
