Name and Version
(llama-server)
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: yes
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce GTX 1660 SUPER, compute capability 7.5, VMM: yes
build: 6259 (710dfc465) with cc (GCC) 15.2.1 20250813 for x86_64-pc-linux-gnu
system info: n_threads = 5, n_threads_batch = 5, total_threads = 12
system_info: n_threads = 5 (n_threads_batch = 5) / 12 | CUDA : ARCHS = 750 | FORCE_MMQ = 1 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | FA_ALL_QUANTS = 1 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 |
Operating systems
Linux
GGML backends
CUDA
Hardware
Ryzen 5 2600 + GTX 1660 Super.
cuda v12.9.1-1
Models
https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main
Problem description & steps to reproduce
(Note: this happens with Mistral, but not with Gemma 3 27B's BF16 mmproj.)
Actual
When trying to use the BF16 mmproj
https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/resolve/main/mmproj-BF16.gguf?download=true
an assertion fails during image processing.
The F16 and F32 variants of the mmproj work fine.
Expected
- The im2col op with a BF16 destination should be supported in the ggml-cuda backend.
- When a required operation is unsupported, llama.cpp should fall back to the CPU backend or, at the very least, fail during model loading instead of aborting mid-inference (see the sketch of such a check after this list).
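On the fallback point: if I understand ggml_backend_sched correctly, the CUDA backend only needs to report BF16 im2col as unsupported from its supports_op handling for the scheduler to place that node on the CPU backend. A minimal sketch of the check follows; the exact wiring point inside ggml-cuda.cu (the GGML_OP_IM2COL branch of supports_op) is my assumption, I have not traced it:

// Sketch only: a predicate for the dst types the CUDA im2col kernels
// implement today. Returning false for BF16 from the backend's supports_op
// path for GGML_OP_IM2COL should let ggml_backend_sched route the node to
// the CPU backend instead of tripping the GGML_ASSERT at runtime.
#include "ggml.h"

static bool cuda_im2col_dst_supported(const struct ggml_tensor * op) {
    return op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32;
}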
Although I'm not familiar with this codebase, CUDA, or how LLMs work internally, I've fixed this on my local branch simply by cloning the im2col_cuda_f16 function and replacing half with nv_bfloat16:
diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu
index 16bb9bec9..df9fedb84 100644
--- a/ggml/src/ggml-cuda/im2col.cu
+++ b/ggml/src/ggml-cuda/im2col.cu
@@ -73,15 +73,22 @@ static void im2col_cuda_f32(const float * x, float * dst,
im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
+static void im2col_cuda_bf16(const float * x, nv_bfloat16 * dst,
+ int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+ int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
+ int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+
+ im2col_cuda<nv_bfloat16>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
+}
+
void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src1_d = (const float *)src1->data;
- float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_BF16);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
@@ -107,8 +114,10 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t IH_IW = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
if(dst->type == GGML_TYPE_F16) {
- im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
+ im2col_cuda_f16(src1_d, (half *) dst->data, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
+ } else if(dst->type == GGML_TYPE_F32) {
+ im2col_cuda_f32(src1_d, (float *) dst->data, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
} else {
- im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
+ im2col_cuda_bf16(src1_d, (nv_bfloat16 *) dst->data, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
}

I'm not going to send a pull request because I doubt I'd be able to create good enough tests in tests/test-backend-ops.cpp, sorry.
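For what it's worth, here is a minimal standalone repro that I would expect to hit the same GGML_ASSERT on master and to complete with the patch above. It is not a tests/test-backend-ops.cpp case; the file name, tensor shapes and build wiring are my own assumptions:

// repro-im2col-bf16.cpp -- sketch, not part of the repo.
// Builds one GGML_OP_IM2COL node with a BF16 destination and runs it on CUDA.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"
#include <cstdio>
#include <vector>

int main() {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // tensor data lives in the backend buffer below
    };
    struct ggml_context * ctx = ggml_init(ip);

    // a = 3x3 kernel with 3 channels, b = 32x32x3 "image" (shapes are arbitrary)
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 3, 1);
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 32, 32, 3, 1);

    // the failing configuration: 2D im2col with a BF16 destination
    struct ggml_tensor * out = ggml_im2col(ctx, a, b, 1, 1, 1, 1, 1, 1, /*is_2D=*/true, GGML_TYPE_BF16);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    ggml_backend_t        backend = ggml_backend_cuda_init(0);
    ggml_backend_buffer_t buf     = ggml_backend_alloc_ctx_tensors(ctx, backend);

    std::vector<float> ones(ggml_nelements(b), 1.0f);
    ggml_backend_tensor_set(b, ones.data(), 0, ggml_nbytes(b));
    ggml_backend_tensor_set(a, ones.data(), 0, ggml_nbytes(a)); // a is smaller, reuse the same data

    // on current master this aborts in ggml_cuda_op_im2col; with the patch it returns SUCCESS
    enum ggml_status st = ggml_backend_graph_compute(backend, gf);
    printf("im2col bf16 status: %d\n", (int) st);

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return st == GGML_STATUS_SUCCESS ? 0 : 1;
}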
First Bad Commit
No response
Relevant log output
main: server is listening on http://127.0.0.1:8080 - starting the main loop
srv update_slots: all slots are idle
srv params_from_: Chat format: Content-only
slot launch_slot_: id 0 | task 0 | processing task
slot update_slots: id 0 | task 0 | new prompt, n_ctx_slot = 12032, n_keep = 0, n_prompt_tokens = 1360
slot update_slots: id 0 | task 0 | kv cache rm [0, end)
slot update_slots: id 0 | task 0 | prompt processing progress, n_past = 512, n_tokens = 512, progress = 0.376471
slot update_slots: id 0 | task 0 | kv cache rm [512, end)
slot update_slots: id 0 | task 0 | prompt processing progress, n_past = 556, n_tokens = 44, progress = 0.408824
slot update_slots: id 0 | task 0 | kv cache rm [556, end)
encoding image slice...
srv process_chun: processing image...
/home/arzeth/llama.cpp-cuda/ggml/src/ggml-cuda/im2col.cu:84: GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32) failed
SIGINT is used by the debugger.
Are you sure you want to change it? (y or n) [answered Y; input not from terminal]
0x00007f471e80b042 in ?? () from /usr/lib/libc.so.6
#0 0x00007f471e80b042 in ?? () from /usr/lib/libc.so.6
#1 0x00007f471e7ff1ac in ?? () from /usr/lib/libc.so.6
#2 0x00007f471e7ff1f4 in ?? () from /usr/lib/libc.so.6
#3 0x00007f471e86fdcf in wait4 () from /usr/lib/libc.so.6
#4 0x00007f471ed6385b in ggml_print_backtrace () at /home/arzeth/llama.cpp-cuda/ggml/src/ggml.c:196
196 waitpid(child_pid, NULL, 0);
#5 0x00007f471ed639c5 in ggml_abort (file=0x7f47217f1de8 "/home/arzeth/llama.cpp-cuda/ggml/src/ggml-cuda/im2col.cu", line=0x54, fmt=0x7f4721838202 "GGML_ASSERT(%s) failed") at /home/arzeth/llama.cpp-cuda/ggml/src/ggml.c:230
230 ggml_print_backtrace();
#6 0x00007f472130c44b in ggml_cuda_op_im2col(ggml_backend_cuda_context&, ggml_tensor*) () from /home/arzeth/llama.cpp-cuda/buildCu/bin/libggml-cuda.so
#7 0x00007f4721305c2d in ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) () from /home/arzeth/llama.cpp-cuda/buildCu/bin/libggml-cuda.so
#8 0x00007f471ed7ec97 in ggml_backend_sched_compute_splits (sched=0x2a) at /home/arzeth/llama.cpp-cuda/ggml/src/ggml-backend.cpp:1487
1487 enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
#9 ggml_backend_sched_graph_compute_async (sched=sched@entry=0x334de283000, graph=<optimized out>) at /home/arzeth/llama.cpp-cuda/ggml/src/ggml-backend.cpp:1681
1681 return ggml_backend_sched_compute_splits(sched);
#10 0x00007f471ed7f14e in ggml_backend_sched_graph_compute (sched=0x334de283000, graph=<optimized out>) at /home/arzeth/llama.cpp-cuda/ggml/src/ggml-backend.cpp:1665
1665 enum ggml_status err = ggml_backend_sched_graph_compute_async(sched, graph);
#11 0x00007f47222240da in clip_image_batch_encode (ctx=ctx@entry=0x334de282100, n_threads=0x5, imgs_c_ptr=imgs_c_ptr@entry=0x334ee1a06f0, vec=0x334f0d50000) at /usr/include/c++/15.2.1/bits/unique_ptr.h:193
193 pointer _M_ptr() const noexcept { return std::get<0>(_M_t); }
#12 0x00007f4722210e87 in mtmd_encode (ctx=0x334d5ee0cc0, image_tokens=0x334ee1a06e0) at /usr/include/c++/15.2.1/bits/stl_vector.h:1395
1395 data() _GLIBCXX_NOEXCEPT
#13 0x00007f47222b1769 in mtmd_helper_eval_chunk_single (ctx=0x334d5ee0cc0, lctx=0x334de261b00, chunk=0x334ee080740, n_past=0x22c, seq_id=0x0, n_batch=<optimized out>, logits_last=0x1, new_n_past=0x7ffe826b9b50) at /home/arzeth/llama.cpp-cuda/tools/mtmd/mtmd-helper.cpp:294
294 ret = mtmd_encode_chunk(ctx, chunk);
#14 0x0000560994b167ea in ?? ()
#15 0x0000560994ae0a74 in ?? ()
#16 0x0000560994a9a0ac in ?? ()
#17 0x00007f471e793675 in ?? () from /usr/lib/libc.so.6
#18 0x00007f471e793729 in __libc_start_main () from /usr/lib/libc.so.6
#19 0x0000560994a9bf05 in ?? ()