1 change: 1 addition & 0 deletions common/arg.cpp
@@ -2869,6 +2869,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"(default: deepseek)",
[](common_params & params, const std::string & value) {
/**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
+ else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
else { throw std::invalid_argument("invalid value"); }
}
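Note: this parser is the only place the CLI string is mapped to the enum; the reverse mapping lives in common_reasoning_format_name() in common/chat.cpp below. Also note that inserting DEEPSEEK_LEGACY before DEEPSEEK changes the latter's numeric value, which is harmless as long as the enum is never serialized numerically (presumably the case here). A minimal self-contained sketch of the round trip; the enum and dispatch mirror the PR, the main() harness is illustrative only:

    #include <cstdio>
    #include <stdexcept>
    #include <string>

    enum common_reasoning_format {
        COMMON_REASONING_FORMAT_NONE,
        COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY,
        COMMON_REASONING_FORMAT_DEEPSEEK,
    };

    // mirrors the dispatch added in common/arg.cpp
    static common_reasoning_format parse_reasoning_format(const std::string & value) {
        if (value == "deepseek")        { return COMMON_REASONING_FORMAT_DEEPSEEK; }
        if (value == "deepseek-legacy") { return COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
        if (value == "none")            { return COMMON_REASONING_FORMAT_NONE; }
        throw std::invalid_argument("invalid value");
    }

    int main() {
        printf("%d\n", parse_reasoning_format("deepseek-legacy")); // prints 1
    }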
15 changes: 8 additions & 7 deletions common/chat.cpp
@@ -82,10 +82,10 @@ json common_chat_msg::to_json_oaicompat() const

std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) {
std::vector<common_chat_msg_diff> diffs;
- // if (previous_msg.reasoning_content != current.reasoning_content) {
- //     auto & diff = diffs.emplace_back();
- //     diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, current.reasoning_content);
- // }
+ if (previous_msg.reasoning_content != new_msg.reasoning_content) {
+     auto & diff = diffs.emplace_back();
+     diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content);
+ }
if (previous_msg.content != new_msg.content) {
auto & diff = diffs.emplace_back();
diff.content_delta = string_diff(previous_msg.content, new_msg.content);
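string_diff() is not shown in this diff; for the deltas above to be well-defined it only needs append-only semantics, i.e. return the suffix of the new value past the previous one. A sketch under that assumption (not the actual llama.cpp implementation):

    #include <stdexcept>
    #include <string>

    // Assumed semantics: streamed fields only ever grow by appending, so the
    // delta is everything in `cur` past the end of `prev`.
    static std::string string_diff(const std::string & prev, const std::string & cur) {
        if (cur.compare(0, prev.size(), prev) != 0) {
            throw std::runtime_error("stream is not append-only");
        }
        return cur.substr(prev.size());
    }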
@@ -385,9 +385,9 @@ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & t

template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
json delta = json::object();
- // if (!diff.reasoning_content_delta.empty()) {
- //     delta["reasoning_content"] = msg.reasoning_content;
- // }
+ if (!diff.reasoning_content_delta.empty()) {
+     delta["reasoning_content"] = diff.reasoning_content_delta;
+ }
if (!diff.content_delta.empty()) {
delta["content"] = diff.content_delta;
}
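With reasoning_content restored in the delta object, an OpenAI-style streaming chunk in deepseek mode would carry roughly the following delta payloads (illustrative values, abbreviated to the delta field):

    {"delta": {"reasoning_content": "First check the units..."}}   while inside <think>
    {"delta": {"content": "The answer is 42."}}                    after </think>

whereas deepseek-legacy keeps the <think>...</think> text inline in content deltas during streaming.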
@@ -598,6 +598,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
switch (format) {
case COMMON_REASONING_FORMAT_NONE: return "none";
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
+ case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
default:
throw std::runtime_error("Unknown reasoning format");
}
2 changes: 1 addition & 1 deletion common/chat.h
@@ -70,7 +70,7 @@ struct common_chat_msg {
};

struct common_chat_msg_diff {
- // std::string reasoning_content_delta;
+ std::string reasoning_content_delta;
std::string content_delta;
size_t tool_call_index = std::string::npos;
common_chat_tool_call tool_call_delta;
3 changes: 2 additions & 1 deletion common/common.h
@@ -215,7 +215,8 @@ struct common_params_vocoder {

enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
- COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
+ COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
+ COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
};

struct common_params {
3 changes: 2 additions & 1 deletion ggml/src/ggml-cpu/CMakeLists.txt
@@ -318,7 +318,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
endif()

- string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
+ string(TOUPPER "${POWER10_M}" POWER10_M_UPPER)
+ string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M_UPPER}")
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")

if (EXTRACTED_NUMBER GREATER_EQUAL 10)
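The TOUPPER pass is presumably needed because prtconf can report the implementation line in mixed case on some systems (e.g. "Power 10" rather than "POWER 10"; illustrative spellings), which the previous case-sensitive MATCHALL on "POWER *([0-9]+)" would miss; uppercasing first makes both the match and the number extraction below robust to either form.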
8 changes: 5 additions & 3 deletions ggml/src/ggml-metal/ggml-metal.m
@@ -4766,14 +4766,16 @@ static bool ggml_metal_encode_node(
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);

+ const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
+
// 2*(2*ncpsg + nqptg)*(nsg)
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
//
// 16*32*(nsg)
// the shared memory needed for the simdgroups to load the KV cache
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
//
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))

int64_t nsgmax = 2;

@@ -4810,9 +4812,9 @@ static bool ggml_metal_encode_node(
// and store the soft_max values and the mask
//
// ne00*(nsg)
- // each simdgroup has a full f16 head vector in shared mem to accumulate results
+ // each simdgroup has a full f32 head vector in shared mem to accumulate results
//
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))

int64_t nsgmax = 2;
while (true) {
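A quick arithmetic check of the first (non-vec) FATTN_SMEM formula above; all parameter values are illustrative assumptions (f16 K/V so is_q = 0, head size ne00 = 128, nqptg = 8, ncpsg = 32; elements are 2 bytes, hence the sizeof(float)/2):

    #include <cstdio>

    // Standalone re-computation of the updated FATTN_SMEM macro from
    // ggml-metal.m; the constants here are assumptions for the example.
    static size_t pad16(size_t x) { return (x + 15) & ~(size_t)15; }

    int main() {
        const int ne00 = 128, nqptg = 8, ncpsg = 32, is_q = 0;
        for (int nsg = 1; nsg <= 4; nsg *= 2) {
            const size_t smem = pad16((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*nsg) + is_q*(16*32*nsg))*(sizeof(float)/2));
            printf("nsg=%d -> %zu bytes\n", nsg, smem); // 6400, 8704, 13312
        }
    }

The doubled ne00 term lines up with the query/output region moving from f16 to f32 in the kernel (see the FA_TYPES change below), and the K-cache staging area is now only reserved when the KV type is actually quantized.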
94 changes: 52 additions & 42 deletions ggml/src/ggml-metal/ggml-metal.metal
@@ -3328,14 +3328,14 @@ kernel void kernel_flash_attn_ext(
constexpr short NW = N_SIMDWIDTH;
constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)

- const short TS = nsg*SH; // shared memory size per query in (s_t == float)
- const short T = DK + 2*TS; // shared memory size per query in (half)
+ const short TS = nsg*SH; // shared memory size per query in (s_t == float)
+ const short T = 2*DK + 2*TS; // shared memory size per query in (half)

- threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
- threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
+ threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
+ threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix

threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
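Most of the new 2* factors in this hunk follow from one detail: shmem_f16 is indexed in half units, while the query/output region now holds f32 values, so every offset past that region doubles. A host-side sketch of the aliasing arithmetic (sizes illustrative):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int DK = 128, Q = 8;
        uint16_t shmem_f16[8192]; // stand-in for the threadgroup buffer
        // Q rows of f32 query/output data occupy 2*Q*DK half-sized slots,
        // so the scratch region that used to start at Q*DK now starts at 2*Q*DK.
        float * s_old = (float *)(shmem_f16 +   Q*DK);
        float * s_new = (float *)(shmem_f16 + 2*Q*DK);
        printf("byte offsets: old=%td new=%td\n",
               (char *)s_old - (char *)shmem_f16,
               (char *)s_new - (char *)shmem_f16); // 2048 vs 4096
    }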
@@ -3354,7 +3354,7 @@
if (iq1 + j < args.ne01) {
sq4[j*DK4 + i] = (q4_t) q4[i];
} else {
- sq4[j*DK4 + i] = (q4_t) 0.0f;
+ sq4[j*DK4 + i] = 0;
}
}
}
@@ -3634,9 +3634,6 @@ kernel void kernel_flash_attn_ext(

// reduce the warps sequentially
for (ushort sg = 1; sg < nsg; ++sg) {
- float S = { 0.0f };
- float M = { -__FLT_MAX__/2 };
-
threadgroup_barrier(mem_flags::mem_threadgroup);

// each simdgroup stores its output to shared memory, reusing sq
@@ -3657,12 +3654,12 @@
const float M0 = ss[j*TS + 1];
const float M1 = ss[j*TS + sg*SH + 1];

- M = max(M0, M1);
+ const float M = max(M0, M1);

const float ms0 = exp(M0 - M);
const float ms1 = exp(M1 - M);

- S = S0*ms0 + S1*ms1;
+ const float S = S0*ms0 + S1*ms1;

if (tiisg == 0) {
ss[j*TS + 0] = S;
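Besides making S and M per-row locals instead of loop-carried values, the merge above is the standard streaming-softmax (log-sum-exp) combination of two partial states; a scalar sketch with made-up numbers:

    #include <cmath>
    #include <cstdio>

    int main() {
        // (S, M): running sum of exp(x - M) and running max from two simdgroups
        const float S0 = 2.5f, M0 = 1.0f; // illustrative values
        const float S1 = 4.0f, M1 = 3.0f;
        const float M   = std::fmax(M0, M1);  // 3
        const float ms0 = std::exp(M0 - M);   // ~0.1353
        const float ms1 = std::exp(M1 - M);   // 1
        const float S   = S0*ms0 + S1*ms1;    // ~4.338
        printf("M=%g S=%g\n", M, S);
        // the two output accumulators are rescaled by ms0/ms1 before summing,
        // which is what the surrounding loop does with the so/so4 rows
    }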
@@ -3701,16 +3698,18 @@
}
}

- device float4 * dst4 = (device float4 *) dst;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK);
+
// final rescale with 1/S and store to global memory
- if (sgitg == 0) {
-     for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
-         const float S = ss[j*TS + 0];
+ for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
+     const float S = 1.0f/sf[j*TS + 0];

-     for (short i = tiisg; i < DV4; i += NW) {
-         dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
-     }
+     device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
+
+     for (short i = tiisg; i < DV4; i += NW) {
+         dst4[i] = (float4) so4[j*DV4 + i]*S;
+     }
}
}
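The store phase changes in two ways: every row is rescaled by a precomputed 1/S read from the f32 scratch area (sf), and rows are round-robined across all nsg simdgroups instead of being serialized on simdgroup 0. The distribution pattern, as a host-side sketch (sizes illustrative; on the GPU the simdgroups run concurrently):

    #include <cstdio>

    int main() {
        const int Q = 8, nsg = 2;
        for (int sgitg = 0; sgitg < nsg; ++sgitg) { // one iteration per simdgroup
            for (int j = sgitg; j < Q; j += nsg) {  // each takes every nsg-th row
                printf("simdgroup %d stores row %d\n", sgitg, j);
            }
        }
    }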
@@ -3719,12 +3718,22 @@
// template to be able to explore different combinations
//
#define FA_TYPES \
- half, half4, simdgroup_half8x8, \
- half, half4x4, simdgroup_half8x8, \
- half, half4x4, simdgroup_half8x8, \
- float, simdgroup_float8x8, \
- float, simdgroup_float8x8, \
- half, half4, simdgroup_half8x8
+ float, float4, simdgroup_float8x8, \
+ half, half4x4, simdgroup_half8x8, \
+ half, half4x4, simdgroup_half8x8, \
+ float, simdgroup_float8x8, \
+ float, simdgroup_float8x8, \
+ float, float4, simdgroup_float8x8
+ //half, half4, simdgroup_half8x8
+
+ #define FA_TYPES_BF \
+ bfloat, bfloat4, simdgroup_bfloat8x8, \
+ bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+ bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+ float, simdgroup_float8x8, \
+ float, simdgroup_float8x8, \
+ float, float4, simdgroup_float8x8
+ //half, half4, simdgroup_half8x8

typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;

@@ -3739,15 +3748,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;

#if defined(GGML_METAL_USE_BF16)
- template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
- template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
- template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
- template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
- template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
- template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
- template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
- template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
- template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
+ template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
+ template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
+ template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
+ template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
+ template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
+ template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
+ template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
+ template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
+ template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif

template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
@@ -3801,6 +3810,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;

#undef FA_TYPES
+ #undef FA_TYPES_BF

template<
typename q4_t, // query types in shared memory
@@ -3847,12 +3857,12 @@ kernel void kernel_flash_attn_ext_vec(

const short T = DK + nsg*SH; // shared memory size per query in (half)

- //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
- threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
- threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
+ threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+ threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results

// store the result for all queries in local memory (the O matrix from the paper)
o4_t lo[DV4/NL];
@@ -4157,7 +4167,7 @@ kernel void kernel_flash_attn_ext_vec(
half4, \
float, \
float, float4, \
- half4
+ float4

typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
