35 changes: 32 additions & 3 deletions common/chat.cpp
@@ -1758,6 +1758,13 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat

     data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, extra_context);
     data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
+    auto supports_thinking = tmpl.source().find("<think>") != std::string::npos;
+
+    // enable_thinking must not stay set when the template does not support <think>
+    if (!supports_thinking && extra_context["enable_thinking"]) {
+        extra_context["enable_thinking"] = false;
+    }
+
     if (string_ends_with(data.prompt, "<think>\n")) {
         if (!extra_context["enable_thinking"]) {
             data.prompt += "</think>";
@@ -1820,9 +1827,31 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
         tool_call_alts.push_back(
             "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
         auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
-        builder.add_rule("root",
-            std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
-            (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
+
+        builder.add_rule("thinking-start", "\"<think>\"");
+        builder.add_rule("thinking-content", "( [^<] | \"<\" [^/] | \"</\" [^t] | \"</t\" [^h] | \"</th\" [^i] | \"</thi\" [^n] | \"</thin\" [^k] | \"</think\" [^>] )*");
+        builder.add_rule("thinking-end", "\"</think>\" space");
+
+        // the thinking part of the grammar depends on whether thinking_forced_open is true (tag already opened, and maybe closed) and on whether thinking is allowed at all
+        std::string thinking_grammar_logic = ""; // empty: thinking tag was already closed, or thinking is not supported/wanted
+        if (extra_context["enable_thinking"]) {
+            data.grammar_triggers.push_back({
+                COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+                data.thinking_forced_open ? "</think>" : "<think>"
+            });
+            if (data.thinking_forced_open) {
+                // the thinking tag was already opened in the prompt, so we don't need to open it again
+                thinking_grammar_logic = "(thinking-content thinking-end) ";
+            }
+            else
+            {
+                thinking_grammar_logic = "(thinking-start thinking-content thinking-end) ";
+            }
+        }
+
+
+        builder.add_rule("root", thinking_grammar_logic + (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
+
         // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
         data.grammar_triggers.push_back({
             COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
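For reference, when `enable_thinking` is active the resulting lazy grammar has roughly this shape (a simplified sketch, not the builder's verbatim output; `tool-call` stands for the tool-call rule assembled above):

```
thinking-start   ::= "<think>"
thinking-content ::= ( [^<] | "<" [^/] | "</" [^t] | ... | "</think" [^>] )*
thinking-end     ::= "</think>" space
# thinking-start is dropped when thinking_forced_open is already true
root             ::= (thinking-start thinking-content thinking-end) tool-call+
# (tool-call)+ only with parallel_tool_calls; a single tool-call otherwise
```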
2 changes: 2 additions & 0 deletions docs/function-calling.md
@@ -21,6 +21,8 @@ Function calling is supported for all models (see https://github.com/ggml-org/ll
 - Use `--chat-template-file` to override the template when appropriate (see examples below)
 - Generic support may consume more tokens and be less efficient than a model's native format.
 
+- Multiple/parallel tool calling is supported by some models but disabled by default; enable it by passing `"parallel_tool_calls": true` in the completion endpoint payload (example below).
+
 <details>
 <summary>Show some common templates and which format handler they use</summary>
 
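To illustrate the new flag, a minimal chat-completion request payload could look like this (the `get_weather` tool is a made-up example):

```json
{
  "messages": [
    { "role": "user", "content": "What is the weather in Paris and in Rome?" }
  ],
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_weather",
        "description": "Get the current weather in a given city",
        "parameters": {
          "type": "object",
          "properties": { "city": { "type": "string" } },
          "required": ["city"]
        }
      }
    }
  ],
  "parallel_tool_calls": true
}
```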
2 changes: 1 addition & 1 deletion ggml/src/ggml-cpu/ggml-cpu-impl.h
@@ -489,7 +489,7 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
 /**
  * @see https://github.com/ggml-org/llama.cpp/pull/14037
  */
-inline float vec_hsum(float32x4_t v) {
+inline static float vec_hsum(float32x4_t v) {
     float32x4_t v_temp = v + vec_reve(v);
     return v_temp[0] + v_temp[1];
 }
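(The added `static` presumably matches the `inline static` convention of the surrounding helpers and sidesteps C's `inline` linkage rules, under which a plain `inline` function emits no out-of-line definition and can leave an undefined reference when a call is not inlined.)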
171 changes: 171 additions & 0 deletions ggml/src/ggml-cuda/conv2d.cu
@@ -0,0 +1,171 @@
#include "conv2d.cuh"

struct conv_params {
    const int64_t IW, IH;
    const int64_t OW, OH;
    const int64_t KW, KH;
    const int64_t ST_X, ST_Y;
    const int64_t PD_X, PD_Y;
    const int64_t DL_X, DL_Y;
    const int64_t IC, OC;
    const int64_t B;
    const int64_t TOTAL;
};

struct kernel_bounds {
    int64_t y_min, y_max;
    int64_t x_min, x_max;
};

__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
    return (a > b) ? a : b;
}

__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
    return (a < b) ? a : b;
}

__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
    kernel_bounds bounds;
    bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
    bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
    bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
    bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
    return bounds;
}

__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
                                                     int64_t kern_coord,
                                                     int64_t stride,
                                                     int64_t dilation,
                                                     int64_t padding) {
    return out_coord * stride + kern_coord * dilation - padding;
}

struct whcn_layout {
    __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
        return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
    }

    __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
        return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
    }

    __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
        return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
    }

    __device__ static void unpack_indices(int64_t global_idx,
                                          const conv_params & P,
                                          int64_t & n,
                                          int64_t & c,
                                          int64_t & out_y,
                                          int64_t & out_x) {
        out_x = global_idx % P.OW;
        out_y = (global_idx / P.OW) % P.OH;
        c     = (global_idx / (P.OW * P.OH)) % P.OC;
        n     = global_idx / (P.OW * P.OH * P.OC);
    }
};

template <typename T, typename Layout>
static __global__ void conv2d_kernel(const float * __restrict__ input,
                                     const T * __restrict__ kernel,
                                     float * __restrict__ output,
                                     const conv_params P) {
    const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (global_idx >= P.TOTAL) {
        return;
    }

    int64_t n, c_out, out_y, out_x;
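    // recover (n, c_out, out_y, out_x) from the flat thread index: one thread
    // per output element, inverting whcn_layout::output_index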
    Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);

    T acc = 0;

    for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
        kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);

        for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
            const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);

            for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
                const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);

                T input_val;
                if (std::is_same<T, half>::value) {
                    input_val = __float2half(input[Layout::input_index(n, c_in, in_y, in_x, P)]);
                } else {
                    input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
                }

                T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
                acc += (input_val * kernel_val);
            }
        }
    }

    // [N, OC, OH, OW]
    output[Layout::output_index(n, c_out, out_y, out_x, P)] = (float) acc;
}

template <typename T>
static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
    const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
    conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
}

static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
    conv2d_cuda<half>(X_D, K_D, Y_D, P, st);
}

static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
    conv2d_cuda<float>(X_D, K_D, Y_D, P, st);
}

void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * kernel = dst->src[0];
    const ggml_tensor * input  = dst->src[1];
    float *             K_D    = (float *) kernel->data;
    const float *       X_D    = (const float *) input->data;
    float *             Y_D    = (float *) dst->data;

    GGML_ASSERT(ggml_is_contiguous(kernel));
    GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);

    // same number of input channels
    GGML_ASSERT(input->ne[2] == kernel->ne[2]);

    cudaStream_t st = ctx.stream();

    const int32_t * p    = (const int32_t *) dst->op_params;
    const int       ST_X = p[0];  // stride_x
    const int       ST_Y = p[1];  // stride_y
    const int       PD_X = p[2];  // padding_x
    const int       PD_Y = p[3];  // padding_y
    const int       DL_X = p[4];  // dilation_x
    const int       DL_Y = p[5];  // dilation_y

    // No cwhn
    GGML_ASSERT(p[6] == false);

    const int IW = input->ne[0];   // input_w
    const int IH = input->ne[1];   // input_h
    const int OW = dst->ne[0];     // output_w
    const int OH = dst->ne[1];     // output_h
    const int KW = kernel->ne[0];  // kernel_w
    const int KH = kernel->ne[1];  // kernel_h
    const int IC = input->ne[2];   // input_channels
    const int OC = kernel->ne[3];  // output_channels
    const int B  = input->ne[3];   // n_batches

    const int64_t total  = B * OC * OH * OW;
    conv_params   params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };

    if (kernel->type == GGML_TYPE_F16) {
        conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
    } else {
        conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
    }
}
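The kernel is launched with one thread per output element (`TOTAL = B * OC * OH * OW`), where the output extents follow standard convolution arithmetic. A host-side sketch of that size relation, for illustration only (not part of the PR):

```cpp
#include <cstdint>

// expected output extent along one axis of a strided, padded, dilated convolution
static int64_t conv_out_size(int64_t in, int64_t k, int64_t stride, int64_t pad, int64_t dil) {
    return (in + 2 * pad - dil * (k - 1) - 1) / stride + 1;
}
// e.g. IW = 32, KW = 3, ST_X = 1, PD_X = 1, DL_X = 1  ->  OW = 32
```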
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/conv2d.cuh
@@ -0,0 +1,5 @@
#pragma once
#include "common.cuh"

#define CUDA_CONV2D_BLOCK_SIZE 256
void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -12,6 +12,7 @@
#include "ggml-cuda/clamp.cuh"
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/conv2d-transpose.cuh"
#include "ggml-cuda/convert.cuh"
@@ -2451,6 +2452,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
+        case GGML_OP_CONV_2D:
+            ggml_cuda_op_conv2d(ctx, dst);
+            break;
         case GGML_OP_CONV_2D_DW:
             ggml_cuda_op_conv2d_dw(ctx, dst);
             break;
@@ -3501,6 +3505,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
         }
         case GGML_OP_IM2COL:
+        case GGML_OP_CONV_2D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
2 changes: 2 additions & 0 deletions tools/server/README.md
@@ -1143,6 +1143,8 @@ The `response_format` parameter supports both plain JSON output (e.g. `{"type":

 `parse_tool_calls`: Whether to parse the generated tool call.
 
+`parallel_tool_calls`: Whether to enable parallel/multiple tool calls (supported only by some models; support is detected from the Jinja chat template).
+
 *Examples:*
 
 You can use either Python `openai` library with appropriate checkpoints: