
Commit cf4f5d2

various fixes
1 parent cf38b47 commit cf4f5d2

10 files changed: +107 −32 lines changed


.editorconfig

Lines changed: 4 additions & 0 deletions
@@ -48,3 +48,7 @@ end_of_line = unset
 charset = unset
 trim_trailing_whitespace = unset
 insert_final_newline = unset
+
+[tools/mtmd/miniaudio.h]
+trim_trailing_whitespace = unset
+insert_final_newline = unset

convert_hf_to_gguf.py

Lines changed: 10 additions & 3 deletions
@@ -1119,6 +1119,8 @@ class VisionModel(ModelBase):
     model_arch = gguf.MODEL_ARCH.CLIP_VISION
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]
+    has_vision_encoder: bool = True
+    has_audio_encoder: bool = False

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -1159,7 +1161,10 @@ def set_type(self):
     def set_gguf_parameters(self):
         self.gguf_writer.add_file_type(self.ftype)
         self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
-        self.gguf_writer.add_vision_has_vision_encoder(True)
+        if self.has_vision_encoder:
+            self.gguf_writer.add_vision_has_vision_encoder(True)
+        if self.has_audio_encoder:
+            self.gguf_writer.add_vision_has_audio_encoder(True)

         # vision config
         self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
@@ -5969,6 +5974,7 @@ def _reverse_hf_permute(data_torch, n_heads, hidden_dim):
 @ModelBase.register("UltravoxModel")
 class UltravoxModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA # dummy
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         raise NotImplementedError("Ultravox does not have text decoder. Please use --mmproj argument")
@@ -5978,6 +5984,8 @@ def __init__(self, *args, **kwargs):
 class UltravoxAudioModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.has_vision_encoder = False
+        self.has_audio_encoder = True
         self.hparams["image_size"] = self.hparams["num_mel_bins"]
         self.hparams["patch_size"] = self.hparams["num_mel_bins"]
         self.hparams["hidden_size"] = self.hparams["d_model"]
@@ -5988,7 +5996,6 @@ def __init__(self, *args, **kwargs):

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        self.gguf_writer.add_bool(gguf.Keys.ClipVision.HAS_AUDIO_ENC, True)
         self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.ULTRAVOX)
         self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
         self.gguf_writer.add_uint32(gguf.Keys.ClipVision.Projector.STACK_FACTOR, self.global_config["stack_factor"])
@@ -5998,7 +6005,7 @@ def tensor_force_quant(self, name, new_name, bid, n_dims):
         if ".conv" in name and ".weight" in name:
             return gguf.GGMLQuantizationType.F16
         return False
-
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unused
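
Note: with these flags, converter subclasses declare which encoders they provide (UltravoxAudioModel above marks itself audio-only), and the matching clip.has_vision_encoder / clip.has_audio_encoder booleans are written to the GGUF metadata. A minimal read-back sketch, assuming ggml's gguf.h C API (gguf_init_from_file, gguf_find_key, gguf_get_val_bool); "mmproj.gguf" is a placeholder path:

    // sketch: inspect the modality flags emitted by the converter
    #include "gguf.h"
    #include <cstdio>

    int main() {
        struct gguf_init_params params = { /*no_alloc =*/ true, /*ctx =*/ nullptr };
        struct gguf_context * gctx = gguf_init_from_file("mmproj.gguf", params);
        if (!gctx) {
            return 1;
        }
        const char * keys[] = { "clip.has_vision_encoder", "clip.has_audio_encoder" };
        for (const char * key : keys) {
            int64_t id = gguf_find_key(gctx, key);
            // flags written via add_bool are simply absent when not set
            printf("%s = %s\n", key, id >= 0 && gguf_get_val_bool(gctx, id) ? "true" : "false");
        }
        gguf_free(gctx);
        return 0;
    }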

gguf-py/gguf/constants.py

Lines changed: 1 addition & 1 deletion
@@ -220,9 +220,9 @@ class Adapter:
     LORA_ALPHA = "adapter.lora.alpha"

 class ClipVision:
-    HAS_AUDIO_ENC = "clip.has_audio_encoder"
     PROJECTOR_TYPE = "clip.projector_type"
     HAS_VISION_ENCODER = "clip.has_vision_encoder"
+    HAS_AUDIO_ENCODER = "clip.has_audio_encoder"
     HAS_LLAVA_PROJECTOR = "clip.has_llava_projector"
     IMAGE_SIZE = "clip.vision.image_size"
     PATCH_SIZE = "clip.vision.patch_size"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
@@ -942,6 +942,9 @@ def add_vision_projection_dim(self, value: int) -> None:
     def add_vision_has_vision_encoder(self, value: bool) -> None:
         self.add_bool(Keys.ClipVision.HAS_VISION_ENCODER, value)

+    def add_vision_has_audio_encoder(self, value: bool) -> None:
+        self.add_bool(Keys.ClipVision.HAS_AUDIO_ENCODER, value)
+
     def add_vision_patch_size(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)

tools/mtmd/clip-impl.h

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 #define KEY_NAME "general.name"
 #define KEY_DESCRIPTION "general.description"
 #define KEY_HAS_AUDIO_ENC "clip.has_audio_encoder"
+#define KEY_HAS_VISION_ENC "clip.has_vision_encoder"
 #define KEY_MINICPMV_VERSION "clip.minicpmv_version"
 #define KEY_USE_GELU "clip.use_gelu"
 #define KEY_USE_SILU "clip.use_silu"

tools/mtmd/clip.cpp

Lines changed: 20 additions & 9 deletions
@@ -165,6 +165,7 @@ enum patch_merge_type {
 };

 struct clip_hparams {
+    bool has_vision = false;
     bool has_audio = false;

     int32_t image_size;
@@ -2029,15 +2030,16 @@ struct clip_model_loader {
         {
             get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy

-            get_bool(KEY_HAS_AUDIO_ENC, hparams.has_audio, false);
-            get_u32(KEY_N_EMBD, hparams.n_embd);
-            get_u32(KEY_N_HEAD, hparams.n_head);
-            get_u32(KEY_N_FF, hparams.n_ff);
-            get_u32(KEY_N_BLOCK, hparams.n_layer);
-            get_u32(KEY_PROJ_DIM, hparams.projection_dim);
-            get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
-            get_u32(KEY_IMAGE_SIZE, hparams.image_size);
-            get_u32(KEY_PATCH_SIZE, hparams.patch_size);
+            get_bool(KEY_HAS_AUDIO_ENC, hparams.has_audio, false);
+            get_bool(KEY_HAS_VISION_ENC, hparams.has_vision, false);
+            get_u32(KEY_N_EMBD, hparams.n_embd);
+            get_u32(KEY_N_HEAD, hparams.n_head);
+            get_u32(KEY_N_FF, hparams.n_ff);
+            get_u32(KEY_N_BLOCK, hparams.n_layer);
+            get_u32(KEY_PROJ_DIM, hparams.projection_dim);
+            get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
+            get_u32(KEY_IMAGE_SIZE, hparams.image_size);
+            get_u32(KEY_PATCH_SIZE, hparams.patch_size);
             get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
             get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);

@@ -2173,6 +2175,7 @@ struct clip_model_loader {
         }

         LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
+        LOG_INF("%s: has_vision_encoder: %d\n", __func__, hparams.has_vision);
         LOG_INF("%s: has_audio_encoder: %d\n", __func__, hparams.has_audio);
         LOG_INF("%s: n_embd: %d\n", __func__, hparams.n_embd);
         LOG_INF("%s: n_head: %d\n", __func__, hparams.n_head);
@@ -3953,6 +3956,14 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
     return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
 }

+bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.has_vision;
+}
+
+bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.has_audio;
+}
+
 bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
     clip_image_f32 clip_img;
     clip_img.buf.resize(h * w * 3);

tools/mtmd/clip.h

Lines changed: 3 additions & 0 deletions
@@ -96,3 +96,6 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img,

 // use by audio input
 void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_step, float * mel);
+
+bool clip_has_vision_encoder(const struct clip_ctx * ctx);
+bool clip_has_audio_encoder(const struct clip_ctx * ctx);
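
Note: a minimal caller-side sketch using only the two predicates declared above; describe_modalities is a hypothetical helper and assumes a clip_ctx already loaded from an mmproj file:

    // hypothetical helper: report which encoders a loaded mmproj provides
    static const char * describe_modalities(const struct clip_ctx * ctx) {
        bool v = clip_has_vision_encoder(ctx);
        bool a = clip_has_audio_encoder(ctx);
        if (v && a) { return "vision+audio"; }
        if (v)      { return "vision"; }
        if (a)      { return "audio"; }
        return "none"; // neither flag was set in the GGUF metadata
    }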

tools/mtmd/mtmd-audio.cpp

Lines changed: 19 additions & 8 deletions
@@ -29,6 +29,9 @@

 // most of the code here is copied from whisper.cpp

+// align x to upper multiple of n
+#define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
+
 namespace whisper_preprocessor {

 #define SIN_COS_N_COUNT WHISPER_N_FFT
@@ -298,9 +301,17 @@ bool preprocess_audio(
         size_t n_samples,
         whisper_filters & filters,
         whisper_mel & output) {
+
+    // a bit hacky, but we want to align the output to a multiple of WHISPER_N_FFT * proj_stack_factor
+    // proj_stack_factor is 8, specifically for Ultravox (so this is a temporary solution)
+
+    size_t n_padded = _ALIGN(n_samples, WHISPER_N_FFT * 8);
+    std::vector<float> samples_padded(n_padded, 0.0f);
+    std::copy(samples, samples + n_samples, samples_padded.data());
+
     return log_mel_spectrogram(
-            samples,
-            n_samples,
+            samples_padded.data(),
+            samples_padded.size(),
             COMMON_SAMPLE_RATE,
             WHISPER_N_FFT,
             WHISPER_HOP_LENGTH,
@@ -391,7 +402,7 @@ bool read_wav_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate,
         ma_decoder_uninit(&decoder);
         return false;
     }
-
+
     double resample_ratio = (double)target_sampler_rate / decoder.outputSampleRate;
     // Reserve for mono output
     pcmf32_mono.reserve(static_cast<size_t>(total_frames_expected_from_decoder * resample_ratio * 1.1) + 1);
@@ -411,9 +422,9 @@ bool read_wav_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate,
     }

     if (frames_decoded_this_iteration == 0 && result == MA_AT_END) { // Ensure we process the last bit if MA_AT_END was from previous read
-         break;
+        break;
     }
-
+
     ma_uint64 frame_count_in = frames_decoded_this_iteration;
     ma_uint64 frame_count_out_capacity;

@@ -423,7 +434,7 @@ bool read_wav_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate,
         ma_decoder_uninit(&decoder);
         return false;
     }
-
+
     size_t current_pcmf32_sample_offset = pcmf32_mono.size();
     // Resize for mono output (channelsOut is 1)
     pcmf32_mono.resize(current_pcmf32_sample_offset + frame_count_out_capacity * data_converter.channelsOut);
@@ -433,7 +444,7 @@ bool read_wav_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate,
     result = ma_data_converter_process_pcm_frames(
         &data_converter,
         temp_decode_buffer.data(),
-        &frame_count_in,
+        &frame_count_in,
         pcmf32_mono.data() + current_pcmf32_sample_offset,
         &frames_actually_output
     );
@@ -443,7 +454,7 @@ bool read_wav_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate,
         ma_decoder_uninit(&decoder);
         return false;
     }
-
+
     // Adjust size to actual frames output (mono)
     pcmf32_mono.resize(current_pcmf32_sample_offset + frames_actually_output * data_converter.channelsOut);
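
Note: a worked example of the padding arithmetic above, assuming WHISPER_N_FFT is 400 (as in whisper.cpp), so the alignment unit for the Ultravox stack factor of 8 is 400 * 8 = 3200 samples:

    // standalone check of the round-up-to-multiple macro
    #include <cstddef>
    #include <cstdio>

    #define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))

    int main() {
        size_t n_samples = 100000;                     // ~6.25 s of 16 kHz mono audio
        size_t n_padded  = _ALIGN(n_samples, 400 * 8); // rounds up to 102400
        // the extra 2400 samples are zero-filled before log_mel_spectrogram
        printf("%zu -> %zu\n", n_samples, n_padded);
        return 0;
    }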

tools/mtmd/mtmd.cpp

Lines changed: 41 additions & 11 deletions
@@ -97,6 +97,8 @@ struct mtmd_context {
     bool print_timings;
     int n_threads;
     std::string image_marker;
+    bool has_vision;
+    bool has_audio;

     // for llava-uhd style models, we need special tokens in-between slices
     // minicpmv calls them "slices", llama 4 calls them "tiles"
@@ -135,7 +137,9 @@ struct mtmd_context {
             throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
         }

-        use_mrope = clip_is_qwen2vl(ctx_clip);
+        has_vision = clip_has_vision_encoder(ctx_clip);
+        has_audio  = clip_has_audio_encoder(ctx_clip);
+        use_mrope  = clip_is_qwen2vl(ctx_clip);

         projector_type proj = clip_get_projector_type(ctx_clip);
         int minicpmv_version = clip_is_minicpmv(ctx_clip);
@@ -362,15 +366,24 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
         output->entries.emplace_back(std::move(chunk));

         // only add image/audio tokens to middle of 2 parts
-        bool is_not_last = &parts.back() != &part;
+        // therefore, we skip handling image/audio if this is the last part
+        if (&parts.back() == &part) {
+            continue;
+        }
+
+        if (!bitmaps[i_bm]->is_audio) {
+            // handle image

-        // handle image
-        if (is_not_last && !bitmaps[i_bm]->is_audio) {
             if (i_bm >= n_bitmaps) {
                 LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
                 return 1;
             }

+            if (!ctx->has_vision) {
+                LOG_ERR("%s: error: model does not support vision input\n", __func__);
+                return 2;
+            }
+
             // convert mtmd_bitmap to clip_image_u8
             clip_image_u8_ptr img_u8(clip_image_u8_init());
             img_u8->nx = bitmaps[i_bm]->nx;
@@ -486,15 +499,20 @@ int32_t mtmd_tokenize(mtmd_context * ctx,

             i_bm++; // move to next image
             continue;
-        }
-
-        // handle audio
-        if (is_not_last && bitmaps[i_bm]->is_audio) {
+
+        } else {
+            // handle audio
+
             if (i_bm >= n_bitmaps) {
                 LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
                 return 1;
             }

+            if (!ctx->has_audio) {
+                LOG_ERR("%s: error: model does not support audio input\n", __func__);
+                return 2;
+            }
+
             // preprocess audio
             whisper_preprocessor::whisper_mel mel_spec;
             GGML_ASSERT(ctx->w_filters.n_mel);
@@ -506,9 +524,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                 return 2;
             }

-            // DEBUG!!!!!!!!!!
-            printf("mel_spec.n_len = %d\n", mel_spec.n_len);
-            printf("mel_spec.n_mel = %d\n", mel_spec.n_mel);
+            // DEBUG!!!
+            // mel_spec.data.resize(220*8*2 * mel_spec.n_mel);
+            // mel_spec.n_len = 220*8*2;
+            LOG_DBG("mel_spec.n_len = %d\n", mel_spec.n_len);
+            LOG_DBG("mel_spec.n_mel = %d\n", mel_spec.n_mel);

             // convert mel spectrogram to clip_image_f32_batch
             clip_image_f32_ptr mel_f32(clip_image_f32_init());
@@ -526,6 +546,8 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
             audio_tokens->batch_f32 = std::move(batch_f32);
             audio_tokens->id = bitmaps[i_bm]->id; // optional

+            LOG_DBG("audio_tokens->n_tokens = %d\n", audio_tokens->n_tokens);
+
             mtmd_input_chunk chunk{
                 MTMD_INPUT_CHUNK_TYPE_AUDIO,
                 {}, // text tokens
@@ -606,6 +628,14 @@ bool mtmd_decode_use_mrope(mtmd_context * ctx) {
     return ctx->use_mrope;
 }

+bool mtmd_support_vision(mtmd_context * ctx) {
+    return ctx->has_vision;
+}
+
+bool mtmd_support_audio(mtmd_context * ctx) {
+    return ctx->has_audio;
+}
+
 // these 2 helpers below use internal clip_image_u8_ptr,
 // so unfortunately they cannot moved to mtmd-helper.h
 // however, in theory, user can decode image file to bitmap using

tools/mtmd/mtmd.h

Lines changed: 5 additions & 0 deletions
@@ -99,6 +99,11 @@ MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
 // whether the current model use M-RoPE for llama_decode
 MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);

+// whether the current model supports vision input
+MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
+
+// whether the current model supports audio input
+MTMD_API bool mtmd_support_audio(mtmd_context * ctx);

 // mtmd_bitmap
 //
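
Note: a caller-side sketch of the new capability checks, mirroring what mtmd_tokenize now does internally (it returns 2 when a bitmap's modality is unsupported); validate_bitmap_modality is a hypothetical helper:

    // hypothetical pre-flight check before handing bitmaps to mtmd_tokenize
    #include "mtmd.h"
    #include <cstdio>

    static bool validate_bitmap_modality(mtmd_context * ctx, bool bitmap_is_audio) {
        if (bitmap_is_audio && !mtmd_support_audio(ctx)) {
            fprintf(stderr, "model does not support audio input\n");
            return false;
        }
        if (!bitmap_is_audio && !mtmd_support_vision(ctx)) {
            fprintf(stderr, "model does not support vision input\n");
            return false;
        }
        return true;
    }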
