1919from  threading  import  Thread 
2020from  typing  import  Any , Dict , Iterator , List , Optional , Tuple 
2121
22+ import  torch 
23+ 
2224from  .....types  import  (
2325    ChatCompletion ,
2426    ChatCompletionAudio ,
3537
3638@register_transformer  
3739@register_non_default_model ("qwen2.5-omni" ) 
38- class  Qwen2_5OmniChatModel (PytorchMultiModalModel ):
40+ @register_non_default_model ("Qwen3-Omni-Thinking" ) 
41+ @register_non_default_model ("Qwen3-Omni-Instruct" ) 
42+ class  QwenOmniChatModel (PytorchMultiModalModel ):
3943    DEFAULT_SYSTEM_PROMPT  =  (
4044        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, " 
4145        "capable of perceiving auditory and visual inputs, as well as generating text and speech." 
4246    )
4347
48+     def  __init__ (self , * args , ** kwargs ):
49+         super ().__init__ (* args , ** kwargs )
50+         # 2.5 or 3 
51+         model_family  =  self .model_family .model_family  or  self .model_family .model_name 
52+         self ._omni_version  =  "2.5"  if  "2.5"  in  model_family  else  "3" 
53+ 
4454    @classmethod  
4555    def  match_json (
4656        cls , model_family : "LLMFamilyV2" , model_spec : "LLMSpecV1" , quantization : str 
4757    ) ->  bool :
4858        if  model_spec .model_format  not  in   ["pytorch" , "gptq" , "awq" , "bnb" ]:
4959            return  False 
5060        llm_family  =  model_family .model_family  or  model_family .model_name 
51-         if  "qwen2.5-omni" .lower () in  llm_family .lower ():
61+         if  (
62+             "qwen2.5-omni" .lower () in  llm_family .lower ()
63+             or  "qwen3-omni" .lower () in  llm_family .lower ()
64+         ):
5265            return  True 
5366        return  False 
5467
@@ -58,15 +71,25 @@ def decide_device(self):
5871        self ._device  =  device 
5972
6073    def  load_processor (self ):
61-         from  transformers  import  Qwen2_5OmniProcessor 
74+         if  self ._omni_version  ==  "2.5" :
75+             from  transformers  import  Qwen2_5OmniProcessor  as  QwenOminiProcessor 
76+         else :
77+             from  transformers  import  Qwen3OmniMoeProcessor  as  QwenOminiProcessor 
6278
63-         self ._processor  =  Qwen2_5OmniProcessor .from_pretrained (
79+         self ._processor  =  QwenOminiProcessor .from_pretrained (
6480            self .model_path , trust_remote_code = True 
6581        )
6682        self ._tokenizer  =  self ._processor .tokenizer 
6783
6884    def  load_multimodal_model (self ):
69-         from  transformers  import  Qwen2_5OmniForConditionalGeneration 
85+         if  self ._omni_version  ==  "2.5" :
86+             from  transformers  import  (
87+                 Qwen2_5OmniForConditionalGeneration  as  QwenOmniForConditionalGeneration ,
88+             )
89+         else :
90+             from  transformers  import  (
91+                 Qwen3OmniMoeForConditionalGeneration  as  QwenOmniForConditionalGeneration ,
92+             )
7093
7194        # for multiple GPU, set back to auto to make multiple devices work 
7295        device  =  "auto"  if  self ._device  ==  "cuda"  else  self ._device 
@@ -79,7 +102,7 @@ def load_multimodal_model(self):
79102        kwargs  =  self .apply_bnb_quantization (kwargs )
80103        logger .debug ("Loading model with extra kwargs: %s" , kwargs )
81104
82-         self ._model  =  Qwen2_5OmniForConditionalGeneration .from_pretrained (
105+         self ._model  =  QwenOmniForConditionalGeneration .from_pretrained (
83106            self .model_path ,
84107            torch_dtype = "auto" ,
85108            device_map = device ,
@@ -181,11 +204,37 @@ def generate_non_streaming(
181204        inputs  =  self .build_inputs_from_messages (messages , generate_config )  # type: ignore 
182205        use_audio_in_video  =  generate_config .get ("use_audio_in_video" , True )
183206        gen_kwargs  =  dict (** inputs , ** config , use_audio_in_video = use_audio_in_video )
184-         generated_ids , audio  =  self ._model .generate (** gen_kwargs )
185-         generated_ids_trimmed  =  [
186-             out_ids [len (in_ids ) :]
187-             for  in_ids , out_ids  in  zip (inputs .input_ids , generated_ids )
188-         ]
207+         # === Run model.generate() (handle both (ids, audio) and ids-only cases) === 
208+         result  =  self ._model .generate (** gen_kwargs )
209+         if  isinstance (result , tuple ) and  len (result ) ==  2 :
210+             # Qwen2.5-Omni returns (generated_ids, audio) 
211+             generated_ids , audio  =  result 
212+         else :
213+             # Qwen3-Omni returns only generated_ids 
214+             generated_ids , audio  =  result , None 
215+         if  hasattr (generated_ids , "sequences" ):
216+             generated_ids  =  generated_ids .sequences 
217+ 
218+         # === Handle text decoding === 
219+         input_len  =  inputs .input_ids .shape [1 ]
220+         # Ensure we have a consistent 2D structure 
221+         # Normalize to list[list[int]] 
222+         if  isinstance (generated_ids , torch .Tensor ):
223+             generated_ids  =  generated_ids .tolist ()
224+         elif  isinstance (generated_ids , list ) and  all (
225+             isinstance (x , int ) for  x  in  generated_ids 
226+         ):
227+             # Single sequence as flat list of ints 
228+             generated_ids  =  [generated_ids ]
229+         elif  isinstance (generated_ids , list ) and  all (
230+             isinstance (x , list ) for  x  in  generated_ids 
231+         ):
232+             pass   # already correct 
233+         else :
234+             raise  TypeError (f"Unexpected generated_ids type: { type (generated_ids )}  " )
235+ 
236+         # Remove prompt tokens 
237+         generated_ids_trimmed  =  [out_ids [input_len :] for  out_ids  in  generated_ids ]
189238        output_text  =  self ._processor .batch_decode (
190239            generated_ids_trimmed ,
191240            skip_special_tokens = True ,
0 commit comments