Skip to content

Commit d587602

Browse files
committed
Merge commit 'f4586ee5986d6f965becb37876d6f3666478a961' into concedo_experimental
# Conflicts:
#	README.md
#	docs/multimodal/minicpmo2.6.md
#	docs/multimodal/minicpmv2.6.md
#	ggml/src/ggml-cann/aclnn_ops.cpp
#	ggml/src/ggml-cann/ggml-cann.cpp
#	ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
#	ggml/src/ggml-cuda/CMakeLists.txt
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-opencl/kernels/add.cl
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	tools/perplexity/perplexity.cpp
#	tools/server/README.md
2 parents 302bb8c + f4586ee commit d587602

File tree

17 files changed

+671
-336
lines changed

17 files changed

+671
-336
lines changed

common/arg.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2951,11 +2951,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
29512951
"- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
29522952
"(default: auto)",
29532953
[](common_params & params, const std::string & value) {
2954-
/**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
2955-
else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
2956-
else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
2957-
else if (value == "auto") { params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; }
2958-
else { throw std::invalid_argument("invalid value"); }
2954+
params.reasoning_format = common_reasoning_format_from_name(value);
29592955
}
29602956
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
29612957
add_opt(common_arg(

common/chat.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,17 @@ common_chat_templates_ptr common_chat_templates_init(
552552
default_template_src = CHATML_TEMPLATE_SRC;
553553
}
554554
}
555+
556+
// TODO @ngxson : this is a temporary hack to prevent chat template from throwing an error
557+
// Ref: https://github.com/ggml-org/llama.cpp/pull/15230#issuecomment-3173959633
558+
if (default_template_src.find("<|channel|>") != std::string::npos
559+
// search for the error message and patch it
560+
&& default_template_src.find("in message.content or") != std::string::npos) {
561+
string_replace_all(default_template_src,
562+
"{%- if \"<|channel|>analysis<|message|>\" in message.content or \"<|channel|>final<|message|>\" in message.content %}",
563+
"{%- if false %}");
564+
}
565+
555566
std::string token_bos = bos_token_override;
556567
std::string token_eos = eos_token_override;
557568
bool add_bos = false;
@@ -625,6 +636,19 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
625636
}
626637
}
627638

639+
common_reasoning_format common_reasoning_format_from_name(const std::string & format) {
640+
if (format == "none") {
641+
return COMMON_REASONING_FORMAT_NONE;
642+
} else if (format == "auto") {
643+
return COMMON_REASONING_FORMAT_AUTO;
644+
} else if (format == "deepseek") {
645+
return COMMON_REASONING_FORMAT_DEEPSEEK;
646+
} else if (format == "deepseek-legacy") {
647+
return COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY;
648+
}
649+
throw std::runtime_error("Unknown reasoning format: " + format);
650+
}
651+
628652
static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
629653
std::string arguments;
630654
if (builder.is_partial()) {

common/chat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ std::string common_chat_format_example(
191191

192192
const char* common_chat_format_name(common_chat_format format);
193193
const char* common_reasoning_format_name(common_reasoning_format format);
194+
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
194195
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
195196

196197
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);

convert_hf_to_gguf.py

Lines changed: 255 additions & 99 deletions
Large diffs are not rendered by default.

convert_lora_to_gguf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
340340
sys.exit(1)
341341
else:
342342
logger.info(f"Loading base model: {dir_base_model.name}")
343-
hparams = ModelBase.load_hparams(dir_base_model)
343+
hparams = ModelBase.load_hparams(dir_base_model, False)
344344

345345
with torch.inference_mode():
346346
try:

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,11 @@ static bool turing_mma_available(const int cc) {
316316
}
317317

318318
static bool ampere_mma_available(const int cc) {
319-
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
319+
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
320320
}
321321

322322
static bool cp_async_available(const int cc) {
323-
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
323+
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
324324
}
325325

326326
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 152 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,117 @@
1+
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
2+
#define USE_CUB
3+
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
4+
5+
#ifdef USE_CUB
6+
#include <cub/cub.cuh>
7+
using namespace cub;
8+
#endif // USE_CUB
9+
110
#include "ssm-scan.cuh"
211

3-
template <size_t splitD, size_t N>
4-
__global__ void __launch_bounds__(splitD, 2)
5-
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
6-
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
12+
// We would like to keep pragma unroll for cases where L_template is not 0,
13+
// so we suppress the clang transformation warning.
14+
#ifdef __clang__
15+
#pragma clang diagnostic push
16+
#pragma clang diagnostic ignored "-Wpass-failed"
17+
#endif // __clang__
18+
template <size_t splitD, size_t N, size_t L_template>
19+
__global__ void __launch_bounds__(splitD, 1)
20+
ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
21+
const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
722
const int32_t * __restrict__ src6, float * __restrict__ dst,
823
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
924
const int src2_nb1, const int src2_nb2, const int src3_nb1,
1025
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
11-
const int64_t s_off, const int64_t d_inner, const int64_t L) {
12-
13-
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
14-
const int bidx = blockIdx.x; // split along B (sequences)
15-
const int bidy = blockIdx.y; // split along D (d_inner)
16-
const int tid = threadIdx.x;
17-
const int wid = tid / 32;
18-
const int wtid = tid % 32;
19-
20-
extern __shared__ float smem[];
21-
const int stride_sA = N + 1;
22-
const int stride_ss0 = N + 1;
23-
float * smem_A = smem;
24-
float * smem_s0 = smem_A + splitD * stride_sA;
25-
26-
const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
27-
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
28-
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
29-
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
30-
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
31-
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
32-
float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
33-
float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
34-
35-
const int stride_s0 = src0_nb2 / sizeof(float);
36-
const int stride_x = src1_nb2 / sizeof(float);
26+
const int64_t s_off, const int64_t d_inner, const int64_t L_param)
27+
{
28+
const size_t L = L_template == 0 ? L_param : L_template;
29+
const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
30+
const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
31+
const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
32+
const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
33+
const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
34+
const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
35+
float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
36+
float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);
37+
38+
const int stride_x = src1_nb2 / sizeof(float);
3739
const int stride_dt = src2_nb1 / sizeof(float);
38-
const int stride_A = src3_nb1 / sizeof(float);
39-
const int stride_B = src4_nb2 / sizeof(float);
40-
const int stride_C = src5_nb2 / sizeof(float);
41-
const int stride_s = stride_s0;
42-
const int stride_y = d_inner;
40+
const int stride_B = src4_nb2 / sizeof(float);
41+
const int stride_C = src5_nb2 / sizeof(float);
42+
const int stride_y = d_inner;
4343

44-
// can N not be 16? for example 32?
45-
if (N == 16) {
46-
#pragma unroll
47-
for (size_t i = 0; i < splitD / 4; i += 2) {
48-
float value = A_block[(wid * warp_size + i) * stride_A + wtid];
49-
// todo: bank conflict
50-
// I am always confused with how to use the swizzling method to solve
51-
// bank conflit. Hoping somebody can tell me.
52-
smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
53-
}
44+
float regA[N];
45+
float regs0[N];
46+
47+
__shared__ float smemB[N];
48+
__shared__ float smemC[N];
49+
50+
#ifdef USE_CUB
51+
using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
52+
using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;
53+
54+
union CubTempStorage {
55+
typename BlockLoad::TempStorage load_temp;
56+
typename BlockStore::TempStorage store_temp;
57+
};
58+
__shared__ CubTempStorage cub_temp_storage;
59+
60+
BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
61+
BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
62+
#else
63+
const int stride_s0 = src0_nb2 / sizeof(float);
64+
const int stride_A = src3_nb1 / sizeof(float);
5465
#pragma unroll
55-
for (size_t i = 0; i < splitD / 4; i += 2) {
56-
float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
57-
smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
58-
}
66+
for (size_t n = 0; n < N; ++n)
67+
{
68+
regA[n] = A_block[threadIdx.x * stride_A + n];
69+
regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
5970
}
71+
#endif
6072

61-
__syncthreads();
73+
#pragma unroll
74+
for (size_t i = 0; i < L; i++)
75+
{
76+
if (threadIdx.x < N)
77+
{
78+
smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
79+
smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
80+
}
81+
__syncthreads();
6282

63-
for (int64_t i = 0; i < L; i++) {
64-
float dt_soft_plus = dt_block[i * stride_dt + tid];
65-
if (dt_soft_plus <= 20.0f) {
66-
dt_soft_plus = log1pf(exp(dt_soft_plus));
83+
float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
84+
if (dt_soft_plus <= 20.0f)
85+
{
86+
dt_soft_plus = log1pf(expf(dt_soft_plus));
6787
}
68-
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
88+
float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;
89+
6990
float sumf = 0.0f;
7091
#pragma unroll
71-
for (size_t j = 0; j < N; j++) {
72-
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
73-
(B_block[i * stride_B + j] * x_dt);
74-
sumf += state * C_block[i * stride_C + j];
75-
if (i == L - 1) {
76-
s_block[tid * stride_s + j] = state;
77-
} else {
78-
smem_s0[tid * stride_ss0 + j] = state;
79-
}
92+
for (size_t n = 0; n < N; n++)
93+
{
94+
float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
95+
sumf += state * smemC[n];
96+
regs0[n] = state;
8097
}
81-
__syncthreads();
82-
y_block[i * stride_y + tid] = sumf;
98+
y_block[i * stride_y + threadIdx.x] = sumf;
8399
}
100+
101+
#ifdef USE_CUB
102+
BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
103+
#else
104+
const int stride_s = stride_s0;
105+
#pragma unroll
106+
for (size_t n = 0; n < N; ++n)
107+
{
108+
s_block[threadIdx.x * stride_s + n] = regs0[n];
109+
}
110+
#endif
84111
}
112+
#ifdef __clang__
113+
#pragma clang diagnostic pop
114+
#endif // __clang__
85115

86116
// assumes as many threads as d_state
87117
template <int splitH, int d_state>
@@ -201,11 +231,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
201231
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
202232
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
203233
cudaStream_t stream) {
234+
const int threads = 128;
204235
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
205236
if (src3_nb1 == sizeof(float)) {
206237
// Mamba-2
207238
if (d_state == 128) {
208-
const int threads = 128;
209239
GGML_ASSERT(d_state % threads == 0);
210240
// NOTE: can be any power of two between 4 and 64
211241
const int splitH = 16;
@@ -229,18 +259,70 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
229259
GGML_ABORT("doesn't support d_state!=(128 or 256).");
230260
}
231261
} else {
232-
const int threads = 128;
233262
// Mamba-1
234263
GGML_ASSERT(n_head % threads == 0);
235264
GGML_ASSERT(head_dim == 1);
236265
GGML_ASSERT(n_group == 1);
237266
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
238267
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
239268
if (d_state == 16) {
240-
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
241-
src0, src1, src2, src3, src4, src5, src6, dst,
269+
switch (n_tok)
270+
{
271+
case 1:
272+
ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
273+
src0, src1, src2, src3, src4, src5, src6, dst,
274+
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
275+
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
276+
break;
277+
case 2:
278+
ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
279+
src0, src1, src2, src3, src4, src5, src6, dst,
280+
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
281+
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
282+
break;
283+
case 3:
284+
ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
285+
src0, src1, src2, src3, src4, src5, src6, dst,
286+
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
287+
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
288+
break;
289+
case 4:
290+
ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
291+
src0, src1, src2, src3, src4, src5, src6, dst,
292+
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
293+
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
294+
break;
295+
case 5:
296+
ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
297+
src0, src1, src2, src3, src4, src5, src6, dst,
242298
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
243299
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
300+
break;
301+
case 6:
302+
ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
303+
src0, src1, src2, src3, src4, src5, src6, dst,
304+
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
305+
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
306+
break;
307+
case 7:
308+
ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
309+
src0, src1, src2, src3, src4, src5, src6, dst,
310+
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
311+
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
312+
break;
313+
case 8:
314+
ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
315+
src0, src1, src2, src3, src4, src5, src6, dst,
316+
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
317+
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
318+
break;
319+
default:
320+
ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
321+
src0, src1, src2, src3, src4, src5, src6, dst,
322+
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
323+
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
324+
break;
325+
}
244326
} else {
245327
GGML_ABORT("doesn't support d_state!=16.");
246328
}

0 commit comments

Comments
 (0)