
mtmd: qwen3 audio support (qwen3-omni and qwen3-asr)#19441

Open
ngxson wants to merge 9 commits intoggml-org:masterfrom
ngxson:xsn/qwen3a

Conversation

@ngxson
Contributor

@ngxson ngxson commented Feb 8, 2026

Status:

  • qwen3-omni-moe working (vision + audio input)
  • qwen3-asr working

@github-actions github-actions bot added examples python python script changes labels Feb 8, 2026
@samshipengs

Any updates on this?

@yimlin

yimlin commented Mar 16, 2026

Any updates?

QuentinFuxa pushed a commit to QuentinFuxa/llama.cpp that referenced this pull request Mar 18, 2026
Add support for Qwen3-ASR-1.7B model (Qwen3ASRForConditionalGeneration):
- New QWEN3A projector type for audio-only ASR models
- Conv2d encoder (3 layers, stride=2 each, 8x time downsampling)
- Whisper-like transformer encoder (24 layers)
- MLP projector: Linear(1024,1024) -> GELU -> Linear(1024,2048)
- Conversion tested: both mmproj and decoder GGUF files work
- Basic inference tested: model loads, encodes audio, generates output

Based on PR ggml-org#19441 by ngxson (WIP qwen3 audio),
adapted for Qwen3-ASR-only architecture (no vision, no deepstack).
Our attention extraction API (llama_set_attn_heads/llama_get_attn_ith) is untouched.
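The projector shape named in the commit message (Linear(1024,1024) -> GELU -> Linear(1024,2048)) and the 8x time downsampling are easy to sanity-check numerically. A minimal NumPy sketch, with random placeholder weights (only the shapes come from the commit message):

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def audio_projector(x, w1, b1, w2, b2):
    # x: (n_frames, 1024) encoder output -> (n_frames, 2048) text-model embeddings
    return gelu(x @ w1 + b1) @ w2 + b2

rng = np.random.default_rng(0)
w1, b1 = rng.standard_normal((1024, 1024)) * 0.02, np.zeros(1024)
w2, b2 = rng.standard_normal((1024, 2048)) * 0.02, np.zeros(2048)

# three conv layers with stride 2 each give 8x time downsampling: 800 mel frames -> 100
x = rng.standard_normal((100, 1024))
y = audio_projector(x, w1, b1, w2, b2)
print(y.shape)  # (100, 2048)
```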
@michoecho

michoecho commented Mar 25, 2026

I wrote working Qwen3-ASR support for my own use at https://github.com/michoecho/llama.cpp/commits/qwen3_asr_support. (I successfully used it to transcribe some lectures in Chinese.) I don't know if it's good enough for upstreaming, because I wasn't thinking about qwen3-omni at all. (I have no idea what "deepstack" is.) But you could use it as a working base if you are getting wrong results.

At a glance, what mainly seems to be missing from this PR is:

  • To make my changes work properly, I had to fix a preexisting bug in whisper preprocessing which was causing the last audio chunk to be lost during: michoecho@63b3c1e#diff-a027f93a5e0a3fe643975f0ae176db52a3330a9422857b4f6fd9bfbac134c863R384. (I haven't reported it because I'm not 100% sure it's a bug — maybe I'm not seeing something — but I'm 99% sure).
  • The ggml_permute seems to have channels and frames swapped around.
  • Qwen3-ASR uses <|audio_start|> and <|audio_end|> instead of <|audio_bos|> and <|audio_eos|>.
  • The audio encoder expects windowed (/chunked) attention, with window size between 1s and 8s. If you run the encoder with full attention on a 30s chunk, you will get bogus results. I didn't want to implement windowed attention (because I would have to learn how to do that), so I just solved it in the preprocessing layer by splitting audio into 8s chunks instead of 30s chunks.
  • As a comment in this PR acknowledges, the reference implementation runs the conv2d layers on chunks of length 100. I followed the reference implementation. I don't know either if this chunking is necessary. I didn't test the non-chunked variant.
  • The default chat template of Qwen3-ASR doesn't work with llama.cpp. (It expects the audio to be passed via some special params). (My fork doesn't care about chat templates at all either, because my application constructs the prompt directly anyway. But if you want to implement a chat template, the prompt expected by the model isn't anything fancy, it's basically just chatml with audio used as the user message).
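Two of the preprocessing points above (8-second windows instead of 30-second ones, and not dropping the trailing audio) can be combined in one splitter. A hypothetical helper, assuming 16 kHz mono PCM (the function name and interface are made up for illustration):

```python
def split_audio(samples, sample_rate=16000, chunk_seconds=8):
    """Split raw PCM samples into fixed-length chunks.

    The final, possibly shorter chunk is kept rather than dropped --
    dropping it is the preexisting whisper-preprocessing bug described above.
    """
    chunk_len = sample_rate * chunk_seconds
    return [samples[i:i + chunk_len] for i in range(0, len(samples), chunk_len)]

# 20s of audio -> two full 8s chunks plus one 4s remainder
chunks = split_audio(list(range(16000 * 20)))
print([len(c) // 16000 for c in chunks])  # [8, 8, 4]
```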

By the way, note that Qwen3-ForcedAligner (the timestamp predictor model) has the same architecture as Qwen3-ASR, so if you implement support for the latter, you almost get support for the former too. "Almost" because the ForcedAligner is a non-autoregressive classification model. (You put in the encoded audio and the transcribed text with some <timestamp> tokens mixed in, then you run a single prediction on it, and the logits on the <timestamp> tokens will describe the timestamp at those points in the text). I'm not sure how to integrate something like that with llama.cpp's abstractions. For my private use case (generating subtitles for the Chinese lectures) I added "support" for Qwen3-ForcedAligner too, but it's too hacky to post.

@ngxson ngxson changed the title mtmd: (WIP) qwen3 audio support mtmd: qwen3 audio support (qwen3-omni and qwen3-asr) Apr 1, 2026
@ngxson ngxson marked this pull request as ready for review April 1, 2026 23:02
@ngxson ngxson requested review from a team and CISC as code owners April 1, 2026 23:02
@ngxson
Contributor Author

ngxson commented Apr 1, 2026

Both qwen3-omni and qwen3-asr are working with this PR; GGUF files will be uploaded shortly.

@ngxson
Contributor Author

ngxson commented Apr 1, 2026

* To make my changes work properly, I had to fix a preexisting bug in whisper preprocessing which was causing the last audio chunk to be lost during: [michoecho@63b3c1e#diff-a027f93a5e0a3fe643975f0ae176db52a3330a9422857b4f6fd9bfbac134c863R384](https://github.com/michoecho/llama.cpp/commit/63b3c1ec0cb1f73f4cb3a7056ae7356b413452f2#diff-a027f93a5e0a3fe643975f0ae176db52a3330a9422857b4f6fd9bfbac134c863R384). (I haven't reported it because I'm not 100% sure it's a bug — maybe I'm not seeing something — but I'm 99% sure).

Chunking can be implemented in a follow-up PR; for simplicity, this PR processes the input as a single 30s chunk.

* Qwen3-ASR uses `<|audio_start|>` and `<|audio_end|>` instead of `<|audio_bos|>` and `<|audio_eos|>`.

Thanks for pointing that out; it needs to be fixed in this PR.

* The default chat template of Qwen3-ASR doesn't work with llama.cpp. (It expects the audio to be passed via some special params). (My fork doesn't care about chat templates at all either, because my application constructs the prompt directly anyway. But if you want to implement a chat template, the prompt expected by the model isn't anything fancy, it's basically just `chatml` with audio used as the user message).

That was fixed by simply pushing a chatml Jinja template to the GGUF upon conversion.
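For reference, a sketch of what such a chatml prompt could look like. The audio marker tokens are the ones named in this thread; the exact placement of the instruction relative to the audio, and how mtmd substitutes the audio embeddings for the placeholder, are assumptions here:

```python
# Hypothetical prompt layout: plain chatml with the audio as the user turn.
# "<|audio_start|>...<|audio_end|>" marks where the audio embeddings would go.
def build_asr_prompt(instruction="Transcribe the audio."):
    return (
        "<|im_start|>user\n"
        "<|audio_start|>(audio embeddings go here)<|audio_end|>\n"
        f"{instruction}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

prompt = build_asr_prompt()
print(prompt)
```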

By the way, note that Qwen3-ForcedAligner (the timestamp predictor model) has the same architecture as Qwen3-ASR, so if you implement support for the latter, you almost get support for the former too. "Almost" because the ForcedAligner is a non-autoregressive classification model. (You put in the encoded audio and the transcribed text with some <timestamp> tokens mixed in, then you run a single prediction on it, and the logits on the <timestamp> tokens will describe the timestamp at those points in the text). I'm not sure how to integrate something like that with llama.cpp's abstractions. For my private use case (generating subtitles for the Chinese lectures) I added "support" for Qwen3-ForcedAligner too, but it's too hacky to post.

Hmm, yeah, that sounds complicated. I'll see whether it's worth implementing, given that another model (Voxtral from Mistral) has somewhat similar logic.

Comment on lines +4298 to +4301

    return []
    return [(self.map_tensor_name(name), data_torch)]

    return [] # skip other tensors

Member

Suggested change:

    -return []
    -return [(self.map_tensor_name(name), data_torch)]
    -return [] # skip other tensors
    +return
    +yield from super().modify_tensors(data_torch, name, bid)
    +return # skip other tensors

    yield from Qwen2VLVisionModel.modify_tensors(self, data_torch, name, bid)
    elif "audio_tower." in name:
    yield from Qwen25AudioModel.modify_tensors(self, data_torch, name, bid)
    return [] # skip other tensors

Member

Suggested change:

    -return [] # skip other tensors
    +return # skip other tensors

Comment on lines +4942 to +4943

    yield (self.map_tensor_name(name), data_torch)
    return [] # skip other tensors

Member

Suggested change:

    -yield (self.map_tensor_name(name), data_torch)
    -return [] # skip other tensors
    +yield from super().modify_tensors(data_torch, name, bid)
    +return # skip other tensors

Comment on lines +5033 to +5036

    if "thinker_config" in self.hparams:
        vision_config = self.hparams["thinker_config"].get("vision_config", {})
    else:
        vision_config = self.hparams.get("vision_config", {})

Member

Instead of handling this everywhere, can't we just merge in all sub-configs of thinker_config here:

    if "thinker_config" in config:
        # rename for Qwen2.5-Omni
        config["text_config"] = config["thinker_config"]["text_config"]

    return
    if "visual." in name or "audio_tower." in name \
        or "talker." in name or "code2wav." in name:
        return []

Member

Suggested change:

    -return []
    +return

