
Commit 5f12ad7

Merge branch 'ggml-org:master' into mradermacher
2 parents d8861ed + bb16041

32 files changed: +722 −378 lines

common/common.cpp

Lines changed: 3 additions & 0 deletions
@@ -1294,6 +1294,9 @@ std::vector<llama_token> common_tokenize(
     int n_tokens = text.length() + 2 * add_special;
     std::vector<llama_token> result(n_tokens);
     n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    if (n_tokens == std::numeric_limits<int32_t>::min()) {
+        throw std::runtime_error("Tokenization failed: input text too large, tokenization result exceeds int32_t limit");
+    }
     if (n_tokens < 0) {
         result.resize(-n_tokens);
         int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
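For context, a minimal self-contained sketch of the two-pass sizing pattern this hunk hardens (not the upstream code; `stub_tokenize` is a hypothetical stand-in for `llama_tokenize`): a negative return reports the required buffer size as `-n`, and `INT32_MIN` is the one negative value whose negation overflows `int32_t`, which is exactly the sentinel the new check rejects before the `result.resize(-n_tokens)` below could misbehave.

```cpp
#include <climits>
#include <cstdio>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for llama_tokenize: fills `buf` and returns the token
// count, or the negated required count when `buf_size` is too small.
static int stub_tokenize(int * buf, int buf_size) {
    const int needed = 5;                  // pretend the text yields 5 tokens
    if (buf_size < needed) {
        return -needed;                    // report required size, negated
    }
    for (int i = 0; i < needed; i++) {
        buf[i] = i;
    }
    return needed;
}

int main() {
    std::vector<int> result(2);            // deliberately undersized first guess
    int n = stub_tokenize(result.data(), (int) result.size());
    if (n == INT_MIN) {                    // sentinel: -n would overflow int32_t
        throw std::runtime_error("tokenization result exceeds int32_t limit");
    }
    if (n < 0) {                           // grow to the reported size and retry
        result.resize(-n);
        n = stub_tokenize(result.data(), (int) result.size());
    }
    result.resize(n);
    printf("%zu tokens\n", result.size());
    return 0;
}
```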

convert_hf_to_gguf.py

Lines changed: 1 addition & 1 deletion
@@ -2193,7 +2193,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             name += ".weight"
             if "multi_modal_projector.linear_1" in name:
                 # despite the name with number postfix, this is a single fully connected layer
-                return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC], data_torch)]
+                return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC] + '.weight', data_torch)]
             return [(self.map_tensor_name(name), data_torch)]
         return []

docs/build.md

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # Build llama.cpp locally

-The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](include/llama.h).
+The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](../include/llama.h).

 The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server.

ggml/include/ggml.h

Lines changed: 12 additions & 0 deletions
@@ -489,6 +489,7 @@ extern "C" {
         GGML_OP_UPSCALE, // nearest interpolate
         GGML_OP_PAD,
         GGML_OP_PAD_REFLECT_1D,
+        GGML_OP_ROLL,
         GGML_OP_ARANGE,
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
@@ -1801,6 +1802,17 @@ extern "C" {
             int                   p0,
             int                   p1);

+    // Move tensor elements by an offset given for each dimension. Elements that
+    // are shifted beyond the last position are wrapped around to the beginning.
+    GGML_API struct ggml_tensor * ggml_roll(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   shift0,
+            int                   shift1,
+            int                   shift2,
+            int                   shift3);
+
     // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
     // timesteps: [N,]
     // return: [N, dim]
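A hedged usage sketch for the new API (assumes this commit's `ggml.h` plus the CPU-backend helper `ggml_graph_compute_with_ctx` from `ggml-cpu.h`; the memory sizing is an illustrative guess, not upstream code): rolling `[0, 1, 2, 3]` by `shift0 = 1` should wrap the last element around to the front, per the comment above.

```cpp
#include "ggml.h"
#include "ggml-cpu.h"
#include <cstdio>

int main() {
    // 16 MiB scratch context; plenty for a 4-element graph.
    ggml_init_params ip = { 16u * 1024 * 1024, nullptr, false };
    ggml_context * ctx = ggml_init(ip);

    ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_tensor * r = ggml_roll(ctx, a, /*shift0=*/1, 0, 0, 0);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, r);

    float * in = (float *) a->data;
    for (int i = 0; i < 4; i++) in[i] = (float) i;        // [0, 1, 2, 3]

    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    const float * out = (const float *) r->data;
    for (int i = 0; i < 4; i++) printf("%.0f ", out[i]);  // expected: 3 0 1 2
    printf("\n");

    ggml_free(ctx);
    return 0;
}
```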

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 5 additions & 0 deletions
@@ -1890,6 +1890,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_pad_reflect_1d(params, tensor);
             } break;
+        case GGML_OP_ROLL:
+            {
+                ggml_compute_forward_roll(params, tensor);
+            } break;
         case GGML_OP_ARANGE:
             {
                 ggml_compute_forward_arange(params, tensor);
@@ -2214,6 +2218,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_PAD_REFLECT_1D:
+        case GGML_OP_ROLL:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:

ggml/src/ggml-cpu/ops.cpp

Lines changed: 67 additions & 0 deletions
@@ -6793,6 +6793,73 @@ void ggml_compute_forward_pad_reflect_1d(
     }
 }

+// ggml_compute_forward_roll
+
+static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
+    if (i < 0) {
+        return i + ne;
+    } else if (i >= ne) {
+        return i - ne;
+    }
+    return i;
+}
+
+static void ggml_compute_forward_roll_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src_data = (const float *) src0->data;
+    float * dst_data = (float *) dst->data;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int s0 = ggml_get_op_params_i32(dst, 0);
+    const int s1 = ggml_get_op_params_i32(dst, 1);
+    const int s2 = ggml_get_op_params_i32(dst, 2);
+    const int s3 = ggml_get_op_params_i32(dst, 3);
+
+    const int64_t total      = ne1 * ne2 * ne3;
+    const int64_t per_thread = (total + params->nth) / params->nth;
+    const int64_t start      = params->ith * per_thread;
+    const int64_t end        = std::min(start + per_thread, total);
+
+    for (int64_t i = start; i < end; ++i) {
+        const int64_t i1 = i % ne1;
+        const int64_t i2 = (i / ne1) % ne2;
+        const int64_t i3 = i / (ne2 * ne1);
+        float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
+
+        const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
+        const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
+        const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
+        const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
+
+        const int64_t s = ggml_wrap_index(-s0, ne00);
+        const int64_t n = ne00 - s;
+        ggml_vec_cpy_f32(n, dst_row, src_row + s);
+        ggml_vec_cpy_f32(s, dst_row + n, src_row);
+    }
+}
+
+void ggml_compute_forward_roll(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_roll_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_arange

 static void ggml_compute_forward_arange_f32(
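The row copy above is the interesting part: a roll along dimension 0 never moves elements one at a time, it splits each row into two contiguous chunks. A standalone C++ sketch of that decomposition (`wrap_index` mirrors `ggml_wrap_index`; the sizes and values are made up for illustration):

```cpp
#include <cstdio>
#include <cstring>

// Mirrors ggml_wrap_index: fold an index that is off by at most one period
// back into [0, ne).
static long wrap_index(long i, long ne) {
    if (i < 0)   return i + ne;
    if (i >= ne) return i - ne;
    return i;
}

int main() {
    const long ne00 = 5;
    const int  s0   = 2;                    // roll right by 2
    float src[ne00] = {0, 1, 2, 3, 4};
    float dst[ne00];

    const long s = wrap_index(-s0, ne00);   // 3: split point in the source row
    const long n = ne00 - s;                // 2: length of the tail chunk
    std::memcpy(dst,     src + s, n * sizeof(float)); // tail -> front
    std::memcpy(dst + n, src,     s * sizeof(float)); // head -> back

    for (float v : dst) printf("%.0f ", v); // expected: 3 4 0 1 2
    printf("\n");
    return 0;
}
```

The same wrap is applied once per higher dimension to pick the source row, which is why the forward pass issues exactly two `ggml_vec_cpy_f32` calls per row.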

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params
 void ggml_compute_forward_upscale(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_pad(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_pad_reflect_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
ggml/src/ggml-cuda/conv2d-transpose.cu

Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
+#include <algorithm>
+
+#include "conv2d-transpose.cuh"
+#include "ggml.h"
+
+__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
+                                        float * __restrict__ output, const int in_w, const int in_h, const int out_w,
+                                        const int out_h, const int kernel_w, const int kernel_h, const int stride,
+                                        const int c_in, const int c_out, const int batches) {
+    const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+    const int total_elements = out_w * out_h * c_out * batches;
+
+    if (global_idx >= total_elements) {
+        return;
+    }
+
+    const int out_x_idx = global_idx % out_w;
+    const int out_y_idx = (global_idx / out_w) % out_h;
+    const int c_idx     = (global_idx / (out_w * out_h)) % c_out;
+    const int n_idx     = global_idx / (out_w * out_h * c_out);
+
+    float accumulator = 0;
+    // For each output idx, find the inputs that contribute to it by checking stride alignment and bounds
+
+    for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
+        for (int kh = 0; kh < kernel_h; ++kh) {
+            int in_y = out_y_idx - kh;
+            if (in_y < 0 || in_y % stride) continue;
+            in_y /= stride;
+            if (in_y >= in_h) continue;
+
+            for (int kw = 0; kw < kernel_w; ++kw) {
+                int in_x = out_x_idx - kw;
+                if (in_x < 0 || in_x % stride) continue;
+                in_x /= stride;
+                if (in_x >= in_w) continue;
+
+                const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
+                const int kernel_idx =
+                    (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
+
+                float input_val = input[input_idx];
+                half  kern_val  = kernel[kernel_idx];
+
+                accumulator += input_val * (float) kern_val;
+            }
+        }
+    }
+
+    output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator;
+}
+
+//input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in)
+void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * kernel = dst->src[0];
+    const ggml_tensor * input  = dst->src[1];
+
+    GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+
+    const float * input_data  = (const float *) input->data;
+    float *       output_data = (float *) dst->data;
+    const half *  kernel_data = (const half *) kernel->data;
+
+    const int input_w      = input->ne[0];
+    const int input_h      = input->ne[1];
+    const int output_w     = dst->ne[0];
+    const int output_h     = dst->ne[1];
+    const int channels_in  = input->ne[2];
+    const int channels_out = kernel->ne[2];
+    const int kernel_w     = kernel->ne[0];
+    const int kernel_h     = kernel->ne[1];
+    const int stride       = dst->op_params[0];
+    const int batches      = input->ne[3];
+
+    GGML_ASSERT(channels_in == kernel->ne[3]);
+    GGML_ASSERT(stride > 0);
+
+    cudaStream_t st = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(input));
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    const int total  = (output_w * output_h * channels_out * batches);
+    const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
+
+    conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
+        input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
+        channels_in, channels_out, batches);
+}
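To sanity-check the kernel's index math: it inverts the forward convolution, so input pixel `in_x` contributes to output pixel `out_x` exactly when `out_x - kw == in_x * stride` for some kernel column `kw`, and (for ggml's no-padding `p0` variant, as I read the bounds test) the output width is `(in_w - 1) * stride + kernel_w`. A plain C++ sketch that enumerates contributors the same way the kernel's "stride alignment and bounds" test does (all sizes here are made-up illustration values, not upstream code):

```cpp
#include <cstdio>

int main() {
    const int in_w = 3, kernel_w = 4, stride = 2;
    const int out_w = (in_w - 1) * stride + kernel_w;   // 8, for the p0 (no-padding) case

    printf("out_w = %d\n", out_w);

    // For each output column, list the (in_x, kw) pairs that feed it,
    // mirroring the kernel's inner-loop rejection tests.
    for (int out_x = 0; out_x < out_w; ++out_x) {
        printf("out_x %d <-", out_x);
        for (int kw = 0; kw < kernel_w; ++kw) {
            int in_x = out_x - kw;
            if (in_x < 0 || in_x % stride) continue;    // misaligned with the stride grid
            in_x /= stride;
            if (in_x >= in_w) continue;                 // past the end of the input
            printf(" (in_x=%d, kw=%d)", in_x, kw);
        }
        printf("\n");
    }
    return 0;
}
```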
ggml/src/ggml-cuda/conv2d-transpose.cuh

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+#include "common.cuh"
+
+#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256
+void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
@@ -12,6 +12,7 @@
 #include "ggml-cuda/concat.cuh"
 #include "ggml-cuda/conv-transpose-1d.cuh"
 #include "ggml-cuda/conv2d-dw.cuh"
+#include "ggml-cuda/conv2d-transpose.cuh"
 #include "ggml-cuda/convert.cuh"
 #include "ggml-cuda/count-equal.cuh"
 #include "ggml-cuda/cpy.cuh"
@@ -2356,6 +2357,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CONV_2D_DW:
            ggml_cuda_op_conv2d_dw(ctx, dst);
            break;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            ggml_cuda_conv_2d_transpose_p0(ctx, dst);
+            break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             ggml_cuda_op_conv_transpose_1d(ctx,dst);
             break;
@@ -3267,6 +3271,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         }
         case GGML_OP_IM2COL:
         case GGML_OP_CONV_2D_DW:
+        case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
