Commit cbedda2

FEAT: [model] qwen3-omni (xorbitsai#4137)
1 parent 1a82be8 commit cbedda2

File tree: 1 file changed (+60, -11 lines)

xinference/model/llm/transformers/multimodal/qwen-omni.py

Lines changed: 60 additions & 11 deletions
@@ -19,6 +19,8 @@
 from threading import Thread
 from typing import Any, Dict, Iterator, List, Optional, Tuple

+import torch
+
 from .....types import (
     ChatCompletion,
     ChatCompletionAudio,
@@ -35,20 +37,31 @@

 @register_transformer
 @register_non_default_model("qwen2.5-omni")
-class Qwen2_5OmniChatModel(PytorchMultiModalModel):
+@register_non_default_model("Qwen3-Omni-Thinking")
+@register_non_default_model("Qwen3-Omni-Instruct")
+class QwenOmniChatModel(PytorchMultiModalModel):
     DEFAULT_SYSTEM_PROMPT = (
         "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
         "capable of perceiving auditory and visual inputs, as well as generating text and speech."
     )

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # 2.5 or 3
+        model_family = self.model_family.model_family or self.model_family.model_name
+        self._omni_version = "2.5" if "2.5" in model_family else "3"
+
     @classmethod
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
             return False
         llm_family = model_family.model_family or model_family.model_name
-        if "qwen2.5-omni".lower() in llm_family.lower():
+        if (
+            "qwen2.5-omni".lower() in llm_family.lower()
+            or "qwen3-omni".lower() in llm_family.lower()
+        ):
             return True
         return False

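The `__init__` added above derives `_omni_version` purely from the family-name string: any family containing "2.5" is treated as Qwen2.5-Omni, everything else as Qwen3-Omni. A minimal standalone sketch of that mapping (the helper name is ours; the family strings come from the registrations above):

def omni_version(model_family: str) -> str:
    # Mirrors the expression in __init__: "2.5" if "2.5" in model_family else "3"
    return "2.5" if "2.5" in model_family else "3"

assert omni_version("qwen2.5-omni") == "2.5"
assert omni_version("Qwen3-Omni-Thinking") == "3"
assert omni_version("Qwen3-Omni-Instruct") == "3"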
@@ -58,15 +71,25 @@ def decide_device(self):
         self._device = device

     def load_processor(self):
-        from transformers import Qwen2_5OmniProcessor
+        if self._omni_version == "2.5":
+            from transformers import Qwen2_5OmniProcessor as QwenOmniProcessor
+        else:
+            from transformers import Qwen3OmniMoeProcessor as QwenOmniProcessor

-        self._processor = Qwen2_5OmniProcessor.from_pretrained(
+        self._processor = QwenOmniProcessor.from_pretrained(
             self.model_path, trust_remote_code=True
         )
         self._tokenizer = self._processor.tokenizer

     def load_multimodal_model(self):
-        from transformers import Qwen2_5OmniForConditionalGeneration
+        if self._omni_version == "2.5":
+            from transformers import (
+                Qwen2_5OmniForConditionalGeneration as QwenOmniForConditionalGeneration,
+            )
+        else:
+            from transformers import (
+                Qwen3OmniMoeForConditionalGeneration as QwenOmniForConditionalGeneration,
+            )

         # for multiple GPUs, set device_map back to "auto" so all devices are used
         device = "auto" if self._device == "cuda" else self._device
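Both loaders follow the same pattern: the heavy transformers import is deferred until load time, and the version-specific class is aliased to a single common name so the rest of the method stays version-agnostic. A reduced sketch of the pattern, assuming a transformers build that ships both processor classes (the free function is ours, standing in for the method and its `self._omni_version`):

def load_omni_processor(version: str, model_path: str):
    # Late, version-dispatched import, aliased to one local name
    if version == "2.5":
        from transformers import Qwen2_5OmniProcessor as OmniProcessor
    else:
        from transformers import Qwen3OmniMoeProcessor as OmniProcessor
    return OmniProcessor.from_pretrained(model_path, trust_remote_code=True)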
@@ -79,7 +102,7 @@ def load_multimodal_model(self):
         kwargs = self.apply_bnb_quantization(kwargs)
         logger.debug("Loading model with extra kwargs: %s", kwargs)

-        self._model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
+        self._model = QwenOmniForConditionalGeneration.from_pretrained(
             self.model_path,
             torch_dtype="auto",
             device_map=device,
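For context, `device` here is either a concrete device string or "auto"; "auto" hands placement to accelerate, which shards the weights across all visible GPUs. A hedged sketch of the same loading call outside Xinference (the model id is only an example):

from transformers import Qwen3OmniMoeForConditionalGeneration

# Illustrative: device_map="auto" lets accelerate spread layers over
# available GPUs; torch_dtype="auto" keeps the checkpoint's precision.
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-Omni-30B-A3B-Instruct",  # example model id
    torch_dtype="auto",
    device_map="auto",
)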
@@ -181,11 +204,37 @@ def generate_non_streaming(
         inputs = self.build_inputs_from_messages(messages, generate_config)  # type: ignore
         use_audio_in_video = generate_config.get("use_audio_in_video", True)
         gen_kwargs = dict(**inputs, **config, use_audio_in_video=use_audio_in_video)
-        generated_ids, audio = self._model.generate(**gen_kwargs)
-        generated_ids_trimmed = [
-            out_ids[len(in_ids) :]
-            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-        ]
+        # === Run model.generate() (handle both (ids, audio) and ids-only cases) ===
+        result = self._model.generate(**gen_kwargs)
+        if isinstance(result, tuple) and len(result) == 2:
+            # Qwen2.5-Omni returns (generated_ids, audio)
+            generated_ids, audio = result
+        else:
+            # Qwen3-Omni returns only generated_ids
+            generated_ids, audio = result, None
+        if hasattr(generated_ids, "sequences"):
+            generated_ids = generated_ids.sequences
+
+        # === Handle text decoding ===
+        input_len = inputs.input_ids.shape[1]
+        # Ensure we have a consistent 2D structure:
+        # normalize generated_ids to list[list[int]]
+        if isinstance(generated_ids, torch.Tensor):
+            generated_ids = generated_ids.tolist()
+        elif isinstance(generated_ids, list) and all(
+            isinstance(x, int) for x in generated_ids
+        ):
+            # Single sequence given as a flat list of ints
+            generated_ids = [generated_ids]
+        elif isinstance(generated_ids, list) and all(
+            isinstance(x, list) for x in generated_ids
+        ):
+            pass  # already correct
+        else:
+            raise TypeError(f"Unexpected generated_ids type: {type(generated_ids)}")
+
+        # Remove prompt tokens
+        generated_ids_trimmed = [out_ids[input_len:] for out_ids in generated_ids]
         output_text = self._processor.batch_decode(
             generated_ids_trimmed,
             skip_special_tokens=True,
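The normalization above accepts three shapes for `generated_ids`: a `torch.Tensor`, a flat `list[int]` (a single sequence), or an already-nested `list[list[int]]`. A self-contained check of that branch logic (the `normalize` helper is ours, written only to exercise the same branches):

import torch

def normalize(generated_ids):
    # Same three accepted shapes as generate_non_streaming() above
    if isinstance(generated_ids, torch.Tensor):
        return generated_ids.tolist()
    if isinstance(generated_ids, list) and all(isinstance(x, int) for x in generated_ids):
        return [generated_ids]  # single flat sequence -> wrap
    if isinstance(generated_ids, list) and all(isinstance(x, list) for x in generated_ids):
        return generated_ids  # already list[list[int]]
    raise TypeError(f"Unexpected generated_ids type: {type(generated_ids)}")

assert normalize(torch.tensor([[1, 2, 3]])) == [[1, 2, 3]]
assert normalize([1, 2, 3]) == [[1, 2, 3]]
assert normalize([[1, 2], [3, 4]]) == [[1, 2], [3, 4]]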
