2626 Qwen3OmniMoeCode2WavConfig ,
2727)
2828
29- from megatron .bridge .models .qwen_omni .thinker_model import Qwen3OmniMoeThinkerModel
30- from megatron .bridge .models .qwen_omni .transformer_config import Qwen3OmniTransformerConfig
29+ from megatron .bridge .models .qwen_omni .modeling_qwen3_omni . thinker_model import Qwen3OmniMoeThinkerModel
30+ from megatron .bridge .models .qwen_omni .modeling_qwen3_omni . transformer_config import Qwen3OmniTransformerConfig
3131
3232
3333class Qwen3OmniMoeModel (MegatronModule ):
@@ -73,23 +73,26 @@ def set_input_tensor(self, input_tensor) -> None:
7373
7474 def freeze (
7575 self ,
76- freeze_language_model : bool ,
77- freeze_vision_model : bool ,
78- freeze_vision_projection : bool ,
76+ freeze_language_model : bool = False ,
77+ freeze_vision_model : bool = False ,
78+ freeze_vision_projection : bool = False ,
79+ freeze_audio_model : bool = False ,
7980 ):
8081 """Freeze model modules.
8182
8283 Make specific modules non-trainable by setting requires_grad to False.
8384
8485 Args:
8586 freeze_language_model (bool): Freeze the language model module.
86- freeze_vision_model (bool): Freeze the vision model module (patch_embed, blocks, pos_embed) .
87+ freeze_vision_model (bool): Freeze the vision model module.
8788 freeze_vision_projection (bool): Freeze the vision projection modules (merger and deepstack_merger_list).
89+ freeze_audio_model (bool): Freeze the audio model module.
8890 """
8991 return self .thinker .freeze (
9092 freeze_language_model ,
9193 freeze_vision_model ,
92- freeze_vision_projection
94+ freeze_vision_projection ,
95+ freeze_audio_model ,
9396 )
9497
9598 def forward (
@@ -113,6 +116,7 @@ def forward(
113116 feature_attention_mask = None ,
114117 audio_feature_lengths = None ,
115118 cp_img_num : list [int ] = None ,
119+ images_padded : list [bool ] = None ,
116120 use_audio_in_video = None ,
117121 video_second_per_grid = None ,
118122 ** kwargs ,
@@ -136,6 +140,7 @@ def forward(
136140 feature_attention_mask = feature_attention_mask ,
137141 audio_feature_lengths = audio_feature_lengths ,
138142 cp_img_num = cp_img_num ,
143+ images_padded = images_padded ,
139144 use_audio_in_video = use_audio_in_video ,
140145 video_second_per_grid = video_second_per_grid ,
141146 ** kwargs ,
0 commit comments