@@ -814,6 +814,78 @@ def reshape_kernel(input_tensor, target_shape):
   return mapping


+def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):
+  """Returns mapping from MaxText to HuggingFace Qwen3-Omni weight paths.
+
+  This function combines mappings from different modalities (text, vision, audio, etc.)
+  into a unified parameter mapping for the multi-modal Qwen3-Omni model.
+
+  Args:
+    config (dict): Model configuration dictionary containing modality-specific configs.
+    scan_layers (bool, optional): Whether the model uses scanned layers. Defaults to False.
+
+  Returns:
+    dict: Combined mapping from all modalities.
+  """
+  # Collect all modality mappings
+  mapping = {}
+
+  # Text mapping with "thinker." prefix, reusing the Qwen3-MoE mapping function
+  num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
+  n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
+  text_mapping = QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(
+      config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text}, scan_layers=scan_layers
+  )
+
+  # Add "thinker." prefix to text mapping values
+  for key, value in text_mapping.items():
+    text_mapping[key] = [f"thinker.{v}" for v in value] if isinstance(value, list) else f"thinker.{value}"
+  mapping.update(text_mapping)
+
+  # TODO(hengtaoguo): Add vision, audio, and other modality mappings here similarly
+  # mapping.update(vision_mapping), mapping.update(audio_mapping), etc.
+
+  return mapping
+
+
+def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving_to_hf=False):
+  """Creates parameter transformation functions for Qwen3-Omni.
+
+  This function provides a dictionary of transformation functions (hooks) for
+  converting Qwen3-Omni model parameters between MaxText and Hugging Face formats.
+  It handles embedding padding and kernel reshaping.
+
+  Args:
+    config (dict): Model configuration dictionary containing modality-specific
+      configs (the text config is read from under "thinker_config").
+    scan_layers (bool, optional): Whether the model uses scanned layers.
+      Defaults to False.
+    saving_to_hf (bool, optional): The direction of conversion. True for
+      MaxText to Hugging Face, False for the reverse. Defaults to False.
+
+  Returns:
+    dict: A dictionary mapping MaxText parameter names to their corresponding
+      transformation functions.
+  """
+  # Collect all modality hooks
+  mapping = {}
+
+  # Text hooks, reusing the Qwen3-MoE hook function
+  num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
+  n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
+  text_hooks = QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
+      config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
+      scan_layers=scan_layers,
+      saving_to_hf=saving_to_hf,
+  )
+  mapping.update(text_hooks)
+
+  # TODO(hengtaoguo): Add vision, audio, and other modality mappings here similarly
+  # mapping.update(vision_hooks), mapping.update(audio_hooks), etc.
+
+  return mapping
+
+
 def LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):
   """
   Returns a dictionary mapping from MaxText parameter names to
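As a quick sanity check of the two functions above, here is a minimal usage sketch. It is not part of the diff; the toy layer and expert counts are placeholders rather than the real qwen3-omni-30b-a3b values, and only the `thinker_config.text_config` fields the code actually reads are populated.

```python
# Sketch only, not part of this change. Toy config with placeholder counts;
# the new mapping function reads just thinker_config.text_config.
toy_config = {"thinker_config": {"text_config": {"num_hidden_layers": 2, "num_experts": 4}}}

mapping = QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(toy_config, scan_layers=False)

# The reused Qwen3-MoE text mapping should resurface with every HF-side
# path rewritten under the "thinker." namespace.
for hf_paths in mapping.values():
  paths = hf_paths if isinstance(hf_paths, list) else [hf_paths]
  assert all(p.startswith("thinker.") for p in paths)
```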
@@ -1007,6 +1079,7 @@ def from_hf():
   "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
   "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
   "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
+  "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING,
 }

 HOOK_FNS = {
@@ -1028,4 +1101,5 @@ def from_hf():
   "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
   "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
   "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+  "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
 }
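With the registry entries in place, callers can stay model-agnostic and dispatch by name. A sketch under the same placeholder-config assumption as above:

```python
# Sketch only: dispatching through the MAPPINGS / HOOK_FNS registries.
name = "qwen3-omni-30b-a3b"
toy_config = {"thinker_config": {"text_config": {"num_hidden_layers": 2, "num_experts": 4}}}

mapping = MAPPINGS[name](config=toy_config, scan_layers=False)
hooks = HOOK_FNS[name](config=toy_config, scan_layers=False, saving_to_hf=True)

# Per the docstrings, hooks map MaxText parameter names to transformation
# callables (e.g., embedding padding, kernel reshaping).
assert all(callable(fn) for fn in hooks.values())
```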