mtmd: add Gemma 4 audio conformer encoder support#21421
stephencox-ict wants to merge 3 commits into ggml-org:master
Conversation
Hi @stephencox-ict, thanks for your contribution! Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:
Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.
Nice, it seems to work, but it's not 100% correct (using E4B, F16):
However, the correct transcription should be:
I haven't yet implemented chunked local self-attention. I'm focused on the testing side now and will come back to this.
JohannesGaessler left a comment:
The changes to test-llama-archs.cpp LGTM otherwise. For some of the other files I'm seeing though that you are adding code comments with EM dashes. Please stick to ASCII unless there is a good reason not to.
Fixed
I've tested transcription with the E4B model; it seems to struggle with longer prompts (I tested with a 20-second audio clip in French), only transcribing near the end of the audio. It also crashes with CUDA because of a missing kernel, but that's a one-line patch that should go in a separate PR:

```diff
--- a/ggml/src/ggml-cuda/ssm-conv.cu
+++ b/ggml/src/ggml-cuda/ssm-conv.cu
@@ -134,8 +134,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
     switch (nc) {
         case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
         case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
+        case 5: launch_kernel(std::integral_constant<int, 5>{}); break;
         case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
-        default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
+        default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now.");
     }
 }
```

I suspect that the missing chunked attention might be the culprit.
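For context, the `nc == 5` case corresponds to the Conformer block's causal depthwise Conv1D, which uses a width-5 kernel. A minimal Python sketch of what a per-channel causal depthwise convolution computes (illustrative only; not ggml's actual `ssm_conv` implementation):

```python
def causal_depthwise_conv1d(x, w):
    # Causal depthwise 1D convolution for a single channel:
    # output[t] depends only on x[:t+1] thanks to left zero-padding.
    # x: list of input samples, w: kernel taps (width 5 for Gemma4's conv).
    K = len(w)
    xp = [0.0] * (K - 1) + list(x)
    return [sum(xp[t + k] * w[k] for k in range(K)) for t in range(len(x))]

# With an identity-like kernel (only the last tap set), output equals input.
y = causal_depthwise_conv1d([0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0, 1.0])
```

Each output channel in the real encoder gets its own width-5 kernel, which is why the CUDA dispatch needs the compile-time `nc == 5` specialization.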
Looking into the chunked encoding and CUDA issues.
There is some instability I'm looking into.
All looks fixed now. The stability issues were caused by some unbounded limits. Tested both E2B and E4B on CUDA; both worked well.
Hmm, I still get repetitive text with the E4B model:
My command:
Btw, can you prevent force-pushing to this PR? Force-pushing makes it hard to keep track of line-level changes.
Thanks for that. Reproduced with that file; looking into it. Noted about the commits.
Bug Found:
Add audio processing for Gemma 4 E2B/E4B via a USM-style Conformer.

Architecture:
- 12-layer Conformer: FFN → Self-Attention → Causal Conv1D → FFN → Norm
- Subsampling conv projection: 2x Conv2D(stride=2) with LayerNorm
- Full self-attention with sinusoidal RPE and sliding window mask (24)
- Logit softcapping at 50.0, ClippableLinear clamping
- Output: 1024 → 1536 → RMSNorm → multimodal embedder

Mel preprocessing (dedicated mtmd_audio_preprocessor_gemma4a):
- HTK mel scale, 128 bins, magnitude STFT, mel_floor=1e-3
- Standard periodic Hann window (320 samples), zero-padded to FFT size
- Semicausal left-padding (frame_length/2 samples)
- Frame count matched to PyTorch (unfold formula)
- No pre-emphasis, no Whisper-style normalization
- Mel cosine similarity vs PyTorch: 0.9998

Key fixes:
- Tensor loading dedup: prevent get_tensor() from creating duplicate entries in ctx_data. Fixed with std::set guard.
- ClippableLinear clamp_info loading moved after per-layer tensors.
- Sliding window mask (24 positions) matching PyTorch context_size.
- Skip Whisper normalization for Gemma4 mel output.

Tested on E2B and E4B with CPU and Vulkan backends. Transcribes: "Glad to see things are going well and business is starting to pick up" (matching ground truth).

Ref: ggml-org#21325

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
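To make the mel-preprocessing constants above concrete, here is a hedged sketch of three of the described pieces: the HTK mel scale, the periodic Hann window, and the PyTorch-unfold frame count. Helper names are mine, not the actual mtmd API.

```python
import math

def hz_to_mel_htk(f):
    # HTK mel scale used by Gemma4's preprocessor (not the Slaney variant).
    return 2595.0 * math.log10(1.0 + f / 700.0)

def hann_periodic(n):
    # *Periodic* Hann window (denominator n, not n-1), as in torch.hann_window.
    return [0.5 - 0.5 * math.cos(2.0 * math.pi * i / n) for i in range(n)]

def frame_count(num_samples, frame_length, hop):
    # PyTorch unfold semantics: only full frames, no partial tail frame.
    return 0 if num_samples < frame_length else 1 + (num_samples - frame_length) // hop

win = hann_periodic(320)      # 320-sample window, zero-padded to the FFT size later
mel = hz_to_mel_htk(1000.0)   # HTK scale pins 1000 Hz at roughly 1000 mel
```

The "frame count matched to PyTorch" bullet is exactly the `frame_count` formula: a Whisper-style framer that emits a padded tail frame would disagree with the reference by one frame on most inputs.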
Audio encoder fixes:
- Fix swapped conv norm weight mapping in tensor_mapping.py (A_ENC_CONV_NORM and A_ENC_NORM_CONV had their gemma4 entries inverted, causing the conv pre-norm and internal norm weights to be swapped in GGUF. This produced 0.67 encoder cosine vs PyTorch; now 0.9999)
- Fix causal mask off-by-one: add (gq - gk) < max_past to match PyTorch's dist < left_window_size (was attending to 13 past tokens instead of 12)
- Use -1e9 instead of -INFINITY for masked positions to match PyTorch's attention_invalid_logits_value and avoid NaN in padded attention weights

LM fixes:
- Disable attention logit softcapping for Gemma4 (unlike Gemma2, Gemma4's text model does not use attn softcapping; was incorrectly hardcoded)
- Use BF16-rounded embedding scale constants to match PyTorch's native BF16 training precision (ref: PR ggml-org#21451). Fixes long-context coherence on CPU/Vulkan backends.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Use double-precision trig (sin/cos) instead of float (sinf/cosf) for precomputed FFT twiddle factors, Hann window, and sinusoidal RPE to match PyTorch's precision in the audio encoder preprocessing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
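To see why the switch from `sinf`/`cosf` matters: rounding the twiddle angle to float32 before the trig call already perturbs the angle by around 1e-7 radians, which carries straight through into the precomputed table. A small illustrative sketch:

```python
import math
import struct

def to_f32(x):
    # Round a Python double to the nearest IEEE-754 float32.
    return struct.unpack('<f', struct.pack('<f', x))[0]

def twiddle_angle(k, n):
    # Angle of the k-th twiddle factor for an n-point FFT.
    return -2.0 * math.pi * k / n

k, n = 511, 1024
exact = math.sin(twiddle_angle(k, n))          # double path: sin(double angle)
lossy = math.sin(to_f32(twiddle_angle(k, n)))  # float32 angle, as sinf would see it
err = abs(exact - lossy)  # small but real; such errors accumulate across the table
```

Computing the angle and trig in double and only rounding the final table entry keeps each factor correctly rounded, matching what PyTorch does internally.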
I'm still busy tracing the divergence for the test-2.mp3 file. Making progress, but it takes a bit of time.
Overview
Add audio processing support for Gemma 4 models via a USM-style Conformer encoder.
Architecture:
Chunked local attention (matching PyTorch reference):
Mel preprocessing (dedicated `mtmd_audio_preprocessor_gemma4a`):
Fixes (beyond the initial encoder):
- Conv norm weight mapping (`tensor_mapping.py`): `A_ENC_CONV_NORM` and `A_ENC_NORM_CONV` had their Gemma4 entries inverted, swapping conv pre-norm and internal norm weights. Encoder cosine improved from 0.67 → 0.9999.
- Causal mask off-by-one (`clip.cpp`): Added `(gq - gk) < max_past` to match PyTorch's `dist < left_window_size` (was attending to 13 past tokens instead of 12).
- Mask invalid value (`clip.cpp`): Use `-1e9` instead of `-INFINITY` for masked positions to match PyTorch's `attention_invalid_logits_value`.
- Double-precision preprocessing (`mtmd-audio.cpp`, `clip.cpp`): Use double-precision trig for FFT twiddle factors, Hann window, and sinusoidal RPE computation.
- Attention softcapping (`llama-model.cpp`): Gemma4's text model does NOT use attention logit softcapping (unlike Gemma2). Was incorrectly hardcoded to `true` with default value 50.0.
- BF16 precision rounding (`gemma4-iswa.cpp`): Use BF16-rounded embedding scale constants to reduce divergence from PyTorch's native BF16 training precision (ref: PR "Gemma 4: move some computations to BF16" #21451).
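A hedged sketch of the corrected mask predicate described above (the `-1e9` sentinel and the strict `dist < left_window_size` bound; the window length and function names here are illustrative, the real code builds the mask from per-model constants):

```python
NEG_LIMIT = -1e9  # finite sentinel matching attention_invalid_logits_value;
                  # -inf can turn fully padded rows into NaN after softmax

def sliding_window_mask(n, left_window=12):
    # mask[q][k] is 0.0 where query q may attend key k, else NEG_LIMIT.
    # Causal + strict window: 0 <= q - k < left_window, mirroring
    # PyTorch's `dist < left_window_size` (the off-by-one fix).
    return [[0.0 if 0 <= q - k < left_window else NEG_LIMIT
             for k in range(n)] for q in range(n)]

m = sliding_window_mask(16)
visible = sum(v == 0.0 for v in m[15])  # 12 keys visible: self + 11 past
```

The off-by-one bug corresponds to using `<=` instead of `<` on the distance check: one extra past key per query, enough to diverge from the reference attention weights.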
Test results (E2B Q4_K_M):
Short audio (5.9s LibriSpeech) - works on CPU, Vulkan, and CUDA:
```
Ground truth: "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL"
Output: "Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
```
Known limitation: Longer audio (17s+) still produces repetitive output. The audio encoder output is correct (0.999 cosine vs PyTorch across all 12 layers + output projection), but the LM enters thinking mode and loops. This appears to be an upstream LM precision issue; PyTorch FP32 transcribes correctly with the same encoder output. See PR #21451 for the full BF16 computation fix needed on the LM side.
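On the BF16-rounded constants mentioned above: bfloat16 is effectively float32 truncated to its top 16 bits (with round-to-nearest-even), so a constant the model saw in BF16 during training can be reproduced by rounding the float32 value. An illustrative sketch (the scale value below is hypothetical, not Gemma4's actual constant):

```python
import math
import struct

def round_to_bf16(x):
    # Round a float to bfloat16 precision: keep the top 16 bits of the
    # IEEE-754 float32 bit pattern, rounding to nearest-even at the cut.
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    rounded = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack('<f', struct.pack('<I', rounded))[0]

# A sqrt(d_model)-style embedding scale loses its low mantissa bits in BF16:
scale = math.sqrt(2048.0)          # 45.254833995939045 in double precision
scale_bf16 = round_to_bf16(scale)  # 45.25 with BF16's 8-bit mantissa
```

Using the BF16-rounded constant instead of the exact float32 one keeps the GGML computation numerically aligned with the checkpoint's training-time precision.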
Generation parameters (from model's `generation_config.json`):
`--temp 1.0 --top-k 64 --top-p 0.95`
Additional information
Test plan:
Ref: #21325