diff --git a/.clang-format b/.clang-format
index 45232b80ed8cd..47d96b6b40983 100644
--- a/.clang-format
+++ b/.clang-format
@@ -22,8 +22,8 @@ AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: true
-BinPackArguments: true
-BinPackParameters: true # OnePerLine
+BinPackArguments: false
+BinPackParameters: false # OnePerLine
BitFieldColonSpacing: Both
BreakBeforeBraces: Custom # Attach
BraceWrapping:
@@ -70,15 +70,18 @@ ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
IncludeBlocks: Regroup
IncludeCategories:
- - Regex: '^<.*\.h>'
+ - Regex: '".*"'
Priority: 1
SortPriority: 0
- - Regex: '^<.*'
+ - Regex: '^<.*\.h>'
Priority: 2
SortPriority: 0
- - Regex: '.*'
+ - Regex: '^<.*'
Priority: 3
SortPriority: 0
+ - Regex: '.*'
+ Priority: 4
+ SortPriority: 0
IncludeIsMainRegex: '([-_](test|unittest))?$'
IncludeIsMainSourceRegex: ''
IndentAccessModifiers: false
diff --git a/CODEOWNERS b/CODEOWNERS
index 3186f8eb1c514..4c0dd4b725dd1 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -9,3 +9,4 @@
/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
/ggml/src/ggml-opt.cpp @JohannesGaessler
/ggml/src/gguf.cpp @JohannesGaessler
+/ggml/src/ggml-vulkan/ @0cc4m
diff --git a/README.md b/README.md
index edde61238cb5f..9b2e0f851c9d7 100644
--- a/README.md
+++ b/README.md
@@ -270,7 +270,6 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
| [CANN](docs/build.md#cann) | Ascend NPU |
| [OpenCL](docs/backend/OPENCL.md) | Adreno GPU |
| [WebGPU [In Progress]](docs/build.md#webgpu) | All |
-
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |
## Obtaining and quantizing models
@@ -436,7 +435,7 @@ To learn more about model quantization, [read this documentation](tools/quantize
## [`llama-perplexity`](tools/perplexity)
-#### A tool for measuring the perplexity [^1][^2] (and other quality metrics) of a model over a given text.
+#### A tool for measuring the [perplexity](tools/perplexity/README.md) [^1] (and other quality metrics) of a model over a given text.
-
Measure the perplexity over a text file
@@ -459,8 +458,7 @@ To learn more about model quantization, [read this documentation](tools/quantize
-[^1]: [tools/perplexity/README.md](./tools/perplexity/README.md)
-[^2]: [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity)
+[^1]: [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity)
## [`llama-bench`](tools/llama-bench)
diff --git a/common/arg.cpp b/common/arg.cpp
index c1151f51da17b..80f965cc731f2 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1612,7 +1612,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.antiprompt.emplace_back(value);
}
- ).set_examples({LLAMA_EXAMPLE_MAIN}));
+ ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-sp", "--special"},
string_format("special tokens output enabled (default: %s)", params.special ? "true" : "false"),
diff --git a/docs/build.md b/docs/build.md
index 50dbba486acf6..849c8252694fa 100644
--- a/docs/build.md
+++ b/docs/build.md
@@ -387,12 +387,12 @@ docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/ren
### For Linux users:
-First, follow the the official [Getting Started with the Linux Tarball Vulkan SDK](https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html) guide.
+First, follow the official LunarG [Getting Started with the Linux Tarball Vulkan SDK](https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html) guide to install and set up the Vulkan SDK.
> [!IMPORTANT]
> After completing the first step, ensure that you have used the `source` command on the `setup_env.sh` file inside of the Vulkan SDK in your current terminal session. Otherwise, the build won't work. Additionally, if you close out of your terminal, you must perform this step again if you intend to perform a build. However, there are ways to make this persistent. Refer to the Vulkan SDK guide linked in the first step for more information about any of this.
-Second, after verifying that you have done everything in the Vulkan SDK guide provided in the first step, run the following command to verify that everything is set up correctly:
+Second, after verifying that you have followed all of the SDK installation/setup steps, run the following command to confirm that the SDK is set up correctly before proceeding:
```bash
vulkaninfo
```
@@ -403,10 +403,11 @@ cmake -B build -DGGML_VULKAN=1
cmake --build build --config Release
```
-Finally, after finishing your build, you should be able to do this:
+Finally, after finishing your build, you should be able to do something like this:
```bash
-# Test the output binary (with "-ngl 33" to offload all layers to GPU)
-./build/bin/llama-cli -m "PATH_TO_MODEL" -p "Hi you how are you" -n 50 -e -ngl 33 -t 4
+# Test the output binary
+# "-ngl 99" should offload all of the layers to GPU for most (if not all) models.
+./build/bin/llama-cli -m "PATH_TO_MODEL" -p "Hi you how are you" -ngl 99
# You should see in the output, ggml_vulkan detected your GPU. For example:
# ggml_vulkan: Using Intel(R) Graphics (ADL GT2) | uma: 1 | fp16: 1 | warp size: 32
diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index 66a5ad8d2eddc..d9590b9d0bab8 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -494,9 +494,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
# Fetch KleidiAI sources:
include(FetchContent)
- set(KLEIDIAI_COMMIT_TAG "v1.9.0")
+ set(KLEIDIAI_COMMIT_TAG "v1.11.0")
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
- set(KLEIDIAI_ARCHIVE_MD5 "2a8e1bb55d201557553545536489a017")
+ set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2")
if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
index 910fd0ee4e743..ddd29d002d1ca 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
@@ -22,9 +22,94 @@
#include "kai_common.h"
+#include "simd-mappings.h"
+
#include "kernels.h"
#define NELEMS(x) sizeof(x) / sizeof(*x)
+
+static const size_t INT4_PER_BYTE = 2;
+static const size_t INT4_BITS = 4;
+static const int Q4_0_ZERO_POINT = 8;
+static const size_t INT4_PER_UINT16 = 4;
+
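+// Dequantize one row of a qsi4c32pscalef16-packed RHS tensor back to f32.
+// Rows are packed in groups of nr_pack; within each block of bl values the
+// per-row f16 scales come first, followed by the int4 data in segments of kr
+// values (low nibble = element i, high nibble = element i + bl/2).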
+static void dequantize_row_qsi4c32pscalef16(
+ const void *packed_data,
+ int32_t row_idx,
+ int64_t nc,
+ float *out,
+ size_t nr_pack,
+ size_t packed_row_stride,
+ size_t kr,
+ size_t bl,
+ size_t num_bytes_multiplier
+) {
+ size_t group_idx = row_idx / nr_pack;
+ size_t row_in_group = row_idx % nr_pack;
+ const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+ size_t num_blocks = nc / bl;
+ const uint8_t *block_ptr = packed_group;
+
+ for (size_t b = 0; b < num_blocks; ++b) {
+ uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
+ float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+ const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
+ size_t num_segments = bl / kr;
+ size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
+
+ for (size_t s = 0; s < num_segments; ++s) {
+ const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
+ const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
+ for (size_t k = 0; k < num_bytes_per_segment; ++k) {
+ uint8_t byte = qbytes[k] ^ 0x88;
+ int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
+ int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
+ out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
+ out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
+ }
+ }
+ block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
+ }
+}
+
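+// Dequantize one row of a qsi4c32ps1s0scalef16-packed RHS tensor back to f32.
+// The int4 values are stored as uint16 words interleaved across the nr rows of
+// a group, with the per-block f16 scales placed at the end of each packed group.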
+static void dequantize_row_qsi4c32ps1s0scalef16(
+ const void *packed_data,
+ int32_t row_idx,
+ int64_t k,
+ float *out,
+ size_t nr,
+ size_t packed_row_stride,
+ size_t kr,
+ size_t bl,
+ size_t num_bytes_multiplier
+) {
+ const size_t num_blocks = k / bl;
+ const size_t bl4 = bl / INT4_PER_UINT16;
+
+ size_t group_idx = row_idx / nr;
+ size_t row_in_group = row_idx % nr;
+
+ const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+ const uint16_t *qdata = (const uint16_t *)packed_group;
+ const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));
+
+ for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
+ uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
+ float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+ for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
+ uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];
+
+ for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
+                int v = ((q >> (qidx * INT4_BITS)) & 0xF) - Q4_0_ZERO_POINT;
+                out[block_idx * bl + bl4_idx * INT4_PER_UINT16 + qidx] = v * scale;
+ }
+ }
+ }
+ GGML_UNUSED(kr);
+}
+
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
#if defined(__ARM_FEATURE_SME)
{
@@ -63,8 +148,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
},
/* .rhs_info = */ {
- /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
- /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+ /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
},
/* .required_cpu = */ CPU_FEATURE_SME,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -107,8 +194,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
},
/* .rhs_info = */ {
- /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
- /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+ /* .packed_stride = */ NULL,
+ /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+ /* .to_float = */ NULL,
},
/* .required_cpu = */ CPU_FEATURE_SME,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -154,8 +243,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* .rhs_info = */ {
- /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
- /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .to_float = */ dequantize_row_qsi4c32pscalef16,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -200,8 +291,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* .rhs_info = */ {
- /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
- /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .to_float = */ dequantize_row_qsi4c32pscalef16,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -247,8 +340,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* .rhs_info = */ {
- /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
- /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .to_float = */ dequantize_row_qsi4c32pscalef16,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -293,8 +388,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* .rhs_info = */ {
- /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
- /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .to_float = */ dequantize_row_qsi4c32pscalef16,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h
index 3b268d4a22aca..bc8f33405d1fe 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.h
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.h
@@ -71,12 +71,15 @@ struct rhs_packing_info {
        std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
        std::function<size_t(size_t n, size_t k)>
    > packed_size;
+ size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
std::variant<
        std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t * rhs, const float * bias, void * rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param * params)>,
        std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void * rhs, const void * bias, const void * scale, void * rhs_packed, size_t extra_bytes, const void * params)>
> pack_func;
+ void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
+ size_t kr, size_t bl, size_t num_bytes_multiplier);
};
struct ggml_kleidiai_kernels {
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
index fafe45e6c5c51..3a513a55d7654 100644
--- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -40,6 +40,17 @@ struct ggml_kleidiai_context {
ggml_kleidiai_kernels * kernels;
} static ctx = { CPU_FEATURE_NONE, NULL };
+static const char* cpu_feature_to_string(cpu_feature f) {
+ switch (f) {
+ case CPU_FEATURE_NONE: return "NONE";
+ case CPU_FEATURE_DOTPROD: return "DOTPROD";
+ case CPU_FEATURE_I8MM: return "I8MM";
+ case CPU_FEATURE_SVE: return "SVE";
+ case CPU_FEATURE_SME: return "SME";
+ default: return "UNKNOWN";
+ }
+}
+
static void init_kleidiai_context(void) {
ggml_critical_section_start();
@@ -62,6 +73,11 @@ static void init_kleidiai_context(void) {
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
}
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+#ifndef NDEBUG
+ if (ctx.kernels) {
+ GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
+ }
+#endif
}
ggml_critical_section_end();
}
@@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1
class tensor_traits : public ggml::cpu::tensor_traits {
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+ if (op->op != GGML_OP_MUL_MAT) {
+ return false;
+ }
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
GGML_ASSERT(kernels);
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
@@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
} else if (dst->src[0]->type == GGML_TYPE_F16) {
return compute_forward_kv_cache(params, dst);
}
+ } else if (dst->op == GGML_OP_GET_ROWS) {
+ if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+ return compute_forward_get_rows(params, dst);
+ }
}
return false;
}
@@ -270,6 +293,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
}
bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@@ -342,8 +367,49 @@ class tensor_traits : public ggml::cpu::tensor_traits {
return true;
}
+ bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+ GGML_ASSERT(ctx.kernels);
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
+ kernel_info * kernel = &ctx.kernels->gemm;
+
+ const int64_t nc = ne00;
+ const int64_t nr = ggml_nelements(src1);
+
+ const size_t block_rows = kernel->get_nr();
+ const size_t kr = kernel->get_kr();
+
+ const size_t num_bytes_multiplier = sizeof(uint16_t);
+ const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
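+    // split the requested rows evenly across the available threads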
+ const int dr = (nr + nth - 1) / nth;
+ const int ir0 = dr * ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int64_t i = ir0; i < ir1; ++i) {
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
+ int64_t row_idx = ((const int32_t *)src1->data)[i];
+ GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
+
+ float *out = (float *)((char *)dst->data + i * nb1);
+ rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
+ }
+
+ return true;
+ }
+
public:
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+ GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
GGML_ASSERT(ctx.kernels);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
@@ -351,17 +417,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
size_t kr = ctx.kernels->gemm.get_kr();
size_t sr = ctx.kernels->gemm.get_sr();
-#ifndef NDEBUG
- const size_t repacked_size = variant_call(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
- GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
-#endif
struct kai_rhs_pack_qs4cxs1s0_param params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
        variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
return 0;
-
GGML_UNUSED(data_size);
}
};
@@ -375,8 +436,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
- GGML_UNUSED(buffer);
return GGML_STATUS_SUCCESS;
+ GGML_UNUSED(buffer);
}
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
@@ -418,18 +479,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
GGML_UNUSED(buft);
}
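+
+// Q4_0 tensors in this buffer are stored repacked for KleidiAI; the packed
+// layout can be larger than ggml_nbytes(tensor), so report the packed size
+// when allocating.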
+static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+ GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
+ GGML_ASSERT(ctx.kernels);
+
+ const size_t n = tensor->ne[1];
+ const size_t k = tensor->ne[0];
+ const size_t nr = ctx.kernels->gemm.get_nr();
+ const size_t kr = ctx.kernels->gemm.get_kr();
+
+ return variant_call(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
+
+ GGML_UNUSED(buft);
+}
+
namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
- if (op->op == GGML_OP_MUL_MAT &&
+ if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
op->src[0]->type == GGML_TYPE_Q4_0 &&
op->src[0]->buffer &&
(ggml_n_dims(op->src[0]) == 2) &&
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
+ if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
+ return false;
+ }
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
- if (op->src[1]->type == GGML_TYPE_F32 &&
+ if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
return true;
}
@@ -438,7 +516,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
}
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
- if (op->op == GGML_OP_MUL_MAT) {
+ if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
}
@@ -469,7 +547,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
/* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
- /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
+ /* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
/* .is_host = */ nullptr,
},
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index c9ff4aa321b8b..98ed29bc9c12f 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -102,12 +102,12 @@ if (CUDAToolkit_FOUND)
if (GGML_STATIC)
if (WIN32)
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
- target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
else ()
- target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
endif()
else()
- target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt)
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
endif()
if (GGML_CUDA_NO_VMM)
diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu
index 86a54e42bb7e6..5bb85b4807bcf 100644
--- a/ggml/src/ggml-cuda/im2col.cu
+++ b/ggml/src/ggml-cuda/im2col.cu
@@ -10,7 +10,7 @@ static __global__ void im2col_kernel(
return;
}
- const int64_t ksize = OW * (KH > 1 ? KW : 1);
+ const int64_t ksize = OW * KH;
const int64_t kx = i / ksize;
const int64_t kd = kx * ksize;
const int64_t ky = (i - kd) / OW;
diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt
index ec5d8cf59556b..015fa8f06824e 100644
--- a/ggml/src/ggml-opencl/CMakeLists.txt
+++ b/ggml/src/ggml-opencl/CMakeLists.txt
@@ -105,6 +105,8 @@ set(GGML_OPENCL_KERNELS
pad
repeat
mul_mat_f16_f32
+ conv2d
+ conv2d_f16_f32
)
foreach (K ${GGML_OPENCL_KERNELS})
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index 3388259152b46..a31483b61085a 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -390,6 +390,9 @@ struct ggml_backend_opencl_context {
cl_program program_tanh;
cl_program program_upscale;
cl_program program_concat;
+ cl_program program_conv_2d_f16;
+ cl_program program_conv_2d_f32;
+ cl_program program_conv_2d_f16_f32;
cl_program program_tsembd;
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
@@ -441,6 +444,9 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_upscale_bilinear;
cl_kernel kernel_concat_f32_contiguous;
cl_kernel kernel_concat_f32_non_contiguous;
+ cl_kernel kernel_conv_2d_f16;
+ cl_kernel kernel_conv_2d_f32;
+ cl_kernel kernel_conv_2d_f16_f32;
cl_kernel kernel_timestep_embedding;
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
@@ -1478,6 +1484,47 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
+ // conv2d
+ {
+ #ifdef GGML_OPENCL_EMBED_KERNELS
+ const std::string kernel_src {
+ #include "conv2d.cl.h"
+ };
+ const std::string kernel_src_f16_f32 {
+ #include "conv2d_f16_f32.cl.h"
+ };
+ #else
+ const std::string kernel_src = read_file("conv2d.cl");
+ const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl");
+ #endif
+ if (!kernel_src.empty()) {
+ backend_ctx->program_conv_2d_f16 =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str());
+ CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err));
+ GGML_LOG_CONT(".");
+ backend_ctx->program_conv_2d_f32 =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+ CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err));
+ GGML_LOG_CONT(".");
+ } else {
+ GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
+ backend_ctx->program_conv_2d_f16 = nullptr;
+ backend_ctx->kernel_conv_2d_f16 = nullptr;
+ backend_ctx->program_conv_2d_f32 = nullptr;
+ backend_ctx->kernel_conv_2d_f32 = nullptr;
+ }
+ if (!kernel_src_f16_f32.empty()) {
+ backend_ctx->program_conv_2d_f16_f32 =
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts);
+ CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err));
+ GGML_LOG_CONT(".");
+ } else {
+ GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n");
+ backend_ctx->program_conv_2d_f16_f32 = nullptr;
+ backend_ctx->kernel_conv_2d_f16_f32 = nullptr;
+ }
+ }
+
// mul_mv_id_q4_0_f32_8x_flat
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2361,6 +2408,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
op->src[0]->ne[3] == 1 && op->ne[3] == 1;
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+ case GGML_OP_CONV_2D:
+ return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
+ (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
+ (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
case GGML_OP_CONCAT:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -4998,6 +5049,83 @@ static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_ten
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
}
+static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_TENSOR_BINARY_OP_LOCALS;
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
+ cl_ulong offset1 = extra1->offset + src1->view_offs;
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+ const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13;
+ const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1;
+
+ const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1];
+ const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3];
+ const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5];
+
+ const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type);
+ const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type);
+ const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type);
+
+ const int64_t NPQ = (int64_t)N * OW * OH;
+
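+    // workgroup tiling parameters; these must stay in sync with the BS_*/TS_* #defines in conv2d.cl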
+ const uint32_t BS_K = 64;
+ const uint32_t BS_NPQ = 64;
+ const uint32_t BS_CRS = 16;
+ const uint32_t VEC_SIZE = 4;
+
+ const uint32_t TS_K = 4;
+ const uint32_t TS_NPQ = 8;
+
+ const uint32_t WG_K = BS_K / TS_K;
+ const uint32_t WG_NPQ = BS_NPQ / TS_NPQ;
+
+    auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (work_size + block_size - 1) / block_size; };
+ const uint32_t NB_K = splitWork(Cout, BS_K);
+ const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ);
+
+ cl_kernel kernel;
+ size_t shmem_size;
+
+ if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+ kernel = backend_ctx->kernel_conv_2d_f16;
+ shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4));
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+ kernel = backend_ctx->kernel_conv_2d_f32;
+ shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+ kernel = backend_ctx->kernel_conv_2d_f16_f32;
+ shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
+ } else {
+ GGML_ASSERT(false && "Unsupported data type combination for conv2d");
+ return;
+ }
+
+ cl_uint idx = 0;
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0));
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1));
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd));
+ CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL));
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N));
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H));
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH));
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1));
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1));
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03));
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13));
+ CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3));
+
+ size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 };
+ size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 };
+
+ backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
+}
+
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -6752,6 +6880,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
ggml_cl_upscale(backend, tensor->src[0], tensor);
return true;
+ case GGML_OP_CONV_2D:
+ if (!any_on_device) {
+ return false;
+ }
+ func = ggml_cl_conv_2d;
+ break;
case GGML_OP_CONCAT:
if (!any_on_device) {
return false;
diff --git a/ggml/src/ggml-opencl/kernels/conv2d.cl b/ggml/src/ggml-opencl/kernels/conv2d.cl
new file mode 100644
index 0000000000000..e339c90cff59f
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/conv2d.cl
@@ -0,0 +1,185 @@
+#ifdef USE_FP16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#define T_FLOAT half
+#define T_FLOAT4 half4
+#define VSTORE_T_FLOAT4(data, offset, p) vstore_half4_rte(data, offset, p)
+#else
+#define T_FLOAT float
+#define T_FLOAT4 float4
+#define VSTORE_T_FLOAT4(data, offset, p) vstore4(data, offset, p)
+#endif
+
+#if defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+#define REQD_SUBGROUP_SIZE_128
+#endif
+
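+// Implicit GEMM view of conv2d: A = kernel weights (K x CRS), B = im2col of the
+// source (CRS x NPQ), C = dst, where K = Cout, CRS = Cin*KH*KW, NPQ = N*OH*OW.
+// Each workgroup computes a BS_K x BS_NPQ tile of C, streaming CRS through
+// local memory in chunks of BS_CRS.
+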
+#define T_ACCUM float4
+#define VEC_SIZE 4
+
+#define BS_K 64
+#define BS_NPQ 64
+#define BS_CRS 16
+
+#define TS_K 4
+#define TS_NPQ 8
+
+#define WG_K (BS_K / TS_K)
+#define WG_NPQ (BS_NPQ / TS_NPQ)
+
+#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
+#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)
+
+static inline uint splitWork(uint work_size, uint block_size){
+ return (work_size + block_size - 1) / block_size;
+}
+
+REQD_SUBGROUP_SIZE_128
+kernel void kernel_conv_2d(
+ global void* p_knl,
+ ulong off_knl,
+ global void* p_src,
+ ulong off_src,
+ global void* p_dst,
+ ulong off_dst,
+ local void* shared,
+ uint Cout, uint Cin, uint N,
+ uint KW, uint KH, uint W, uint H, uint OW, uint OH,
+ uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
+ uint nb01, uint nb02, uint nb03,
+ uint nb11, uint nb12, uint nb13,
+ uint nb1, uint nb2, uint nb3
+) {
+ global T_FLOAT* knl_data = (global T_FLOAT*) ((global char*)p_knl + off_knl);
+ global T_FLOAT* src_data = (global T_FLOAT*) ((global char*)p_src + off_src);
+ global T_FLOAT* dst_data = (global T_FLOAT*) ((global char*)p_dst + off_dst);
+
+ const uint K = Cout;
+ const uint CRS = Cin*KH*KW;
+ const uint NPQ = N*OH*OW;
+
+ const uint lid_k = get_local_id(0);
+ const uint lid_npq = get_local_id(1);
+ const uint tid = lid_npq * WG_K + lid_k;
+
+ const uint B_idx_K = get_group_id(0);
+ const uint B_idx_NPQ = get_group_id(1);
+
+ const uint offset_k = B_idx_K * BS_K;
+ const uint offset_npq = B_idx_NPQ * BS_NPQ;
+
+ local T_FLOAT* Ash = (local T_FLOAT*)shared;
+ local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * BS_CRS];
+
+ T_ACCUM regC[TS_K][TS_NPQ_VEC];
+ for (int i = 0; i < TS_K; ++i) {
+ for (int j = 0; j < TS_NPQ_VEC; ++j) {
+ regC[i][j] = (T_ACCUM)(0.0f);
+ }
+ }
+
+ const uint NB_CRS = splitWork(CRS, BS_CRS);
+
+ for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
+ const uint offset_crs = B_idx_CRS * BS_CRS;
+
+ for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
+ const uint k_l = i / BS_CRS;
+ const uint crs_l = i % BS_CRS;
+ const uint k_g = offset_k + k_l;
+ const uint crs_g = offset_crs + crs_l;
+
+ if (k_g < K && crs_g < CRS) {
+ const uint Cin_idx = crs_g / (KW*KH);
+ const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
+ const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
+ const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
+ Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
+ } else {
+ Ash[k_l * BS_CRS + crs_l] = (T_FLOAT)0.0f;
+ }
+ }
+
+ for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
+ const uint crs_l = i / BS_NPQ_VEC;
+ const uint npq_l_vec = i % BS_NPQ_VEC;
+ const uint crs_g = offset_crs + crs_l;
+
+ T_FLOAT4 val = (T_FLOAT4)(0.0f);
+ if (crs_g < CRS) {
+ const uint Cin_idx = crs_g / (KW * KH);
+ const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
+ const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
+ for (int v = 0; v < VEC_SIZE; ++v) {
+ const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
+ if (npq_g < NPQ) {
+ const uint N_idx = npq_g / (OH * OW);
+ const uint pq_idx = npq_g % (OH * OW);
+ const uint OH_idx = pq_idx / OW;
+ const uint OW_idx = pq_idx % OW;
+ const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
+ const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);
+
+ if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
+ const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
+ ((T_FLOAT*)&val)[v] = src_data[src_idx];
+ }
+ }
+ }
+ }
+ Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
+ }
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ #pragma unroll
+ for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
+ T_FLOAT regA[TS_K];
+ for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+ regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
+ }
+
+ for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+ T_FLOAT4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
+ for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+ regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]);
+ }
+ }
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ }
+
+ for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+ const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
+ if (k_g >= K) continue;
+
+ for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+ const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;
+
+ const uint N_idx = npq_g_base / (OH * OW);
+ const uint pq_idx = npq_g_base % (OH * OW);
+ const uint OH_idx = pq_idx / OW;
+ const uint OW_idx = pq_idx % OW;
+
+ if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
+ const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
+ VSTORE_T_FLOAT4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
+ } else {
+ T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
+ for (int v = 0; v < VEC_SIZE; ++v) {
+ const uint npq_g = npq_g_base + v;
+ if (npq_g < NPQ) {
+ const uint N_idx_s = npq_g / (OH*OW);
+ const uint pq_idx_s = npq_g % (OH*OW);
+ const uint OH_idx_s = pq_idx_s / OW;
+ const uint OW_idx_s = pq_idx_s % OW;
+ const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
+ dst_data[dst_idx_s] = (T_FLOAT)(((float*)&res)[v]);
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl b/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl
new file mode 100644
index 0000000000000..cb05637f33ac8
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl
@@ -0,0 +1,176 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#if defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+#define REQD_SUBGROUP_SIZE_128
+#endif
+
+#define T_ACCUM float4
+#define VEC_SIZE 4
+
+#define BS_K 64
+#define BS_NPQ 64
+#define BS_CRS 16
+
+#define TS_K 4
+#define TS_NPQ 8
+
+#define WG_K (BS_K / TS_K)
+#define WG_NPQ (BS_NPQ / TS_NPQ)
+
+#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
+#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)
+
+static inline uint splitWork(uint work_size, uint block_size){
+ return (work_size + block_size - 1) / block_size;
+}
+
+REQD_SUBGROUP_SIZE_128
+kernel void kernel_conv_2d(
+ global void* p_knl,
+ ulong off_knl,
+ global void* p_src,
+ ulong off_src,
+ global void* p_dst,
+ ulong off_dst,
+ local void* shared,
+ uint Cout, uint Cin, uint N,
+ uint KW, uint KH, uint W, uint H, uint OW, uint OH,
+ uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
+ uint nb01, uint nb02, uint nb03,
+ uint nb11, uint nb12, uint nb13,
+ uint nb1, uint nb2, uint nb3
+) {
+ global half* knl_data = (global half*) ((global char*)p_knl + off_knl);
+ global float* src_data = (global float*) ((global char*)p_src + off_src);
+ global float* dst_data = (global float*) ((global char*)p_dst + off_dst);
+
+ const uint K = Cout;
+ const uint CRS = Cin*KH*KW;
+ const uint NPQ = N*OH*OW;
+
+ const uint lid_k = get_local_id(0);
+ const uint lid_npq = get_local_id(1);
+ const uint tid = lid_npq * WG_K + lid_k;
+
+ const uint B_idx_K = get_group_id(0);
+ const uint B_idx_NPQ = get_group_id(1);
+
+ const uint offset_k = B_idx_K * BS_K;
+ const uint offset_npq = B_idx_NPQ * BS_NPQ;
+
+ local half* Ash = (local half*)shared;
+ local float4* Bsh = (local float4*) &Ash[BS_K * BS_CRS];
+
+ T_ACCUM regC[TS_K][TS_NPQ_VEC];
+ for (int i = 0; i < TS_K; ++i) {
+ for (int j = 0; j < TS_NPQ_VEC; ++j) {
+ regC[i][j] = (T_ACCUM)(0.0f);
+ }
+ }
+
+ const uint NB_CRS = splitWork(CRS, BS_CRS);
+
+ for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
+ const uint offset_crs = B_idx_CRS * BS_CRS;
+
+ for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
+ const uint k_l = i / BS_CRS;
+ const uint crs_l = i % BS_CRS;
+ const uint k_g = offset_k + k_l;
+ const uint crs_g = offset_crs + crs_l;
+
+ if (k_g < K && crs_g < CRS) {
+ const uint Cin_idx = crs_g / (KW*KH);
+ const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
+ const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
+ const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
+ Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
+ } else {
+ Ash[k_l * BS_CRS + crs_l] = (half)0.0f;
+ }
+ }
+
+ for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
+ const uint crs_l = i / BS_NPQ_VEC;
+ const uint npq_l_vec = i % BS_NPQ_VEC;
+ const uint crs_g = offset_crs + crs_l;
+
+ float4 val = (float4)(0.0f);
+ if (crs_g < CRS) {
+ const uint Cin_idx = crs_g / (KW * KH);
+ const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
+ const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
+ for (int v = 0; v < VEC_SIZE; ++v) {
+ const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
+ if (npq_g < NPQ) {
+ const uint N_idx = npq_g / (OH * OW);
+ const uint pq_idx = npq_g % (OH * OW);
+ const uint OH_idx = pq_idx / OW;
+ const uint OW_idx = pq_idx % OW;
+ const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
+ const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);
+
+ if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
+ const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
+ ((float*)&val)[v] = src_data[src_idx];
+ }
+ }
+ }
+ }
+ Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
+ }
+
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ #pragma unroll
+ for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
+ half regA[TS_K];
+ for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+ regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
+ }
+
+ for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+ float4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
+ for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+ regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), regB, regC[k_l_reg][npq_l_vec_reg]);
+ }
+ }
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ }
+
+ for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+ const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
+ if (k_g >= K) continue;
+
+ for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+ const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;
+
+ const uint N_idx = npq_g_base / (OH * OW);
+ const uint pq_idx = npq_g_base % (OH * OW);
+ const uint OH_idx = pq_idx / OW;
+ const uint OW_idx = pq_idx % OW;
+
+ if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
+ const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
+ vstore4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
+ } else {
+ T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
+ for (int v = 0; v < VEC_SIZE; ++v) {
+ const uint npq_g = npq_g_base + v;
+ if (npq_g < NPQ) {
+ const uint N_idx_s = npq_g / (OH*OW);
+ const uint pq_idx_s = npq_g % (OH*OW);
+ const uint OH_idx_s = pq_idx_s / OW;
+ const uint OW_idx_s = pq_idx_s % OW;
+ const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
+ dst_data[dst_idx_s] = ((float*)&res)[v];
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/ggml/src/ggml-opencl/kernels/im2col_f16.cl b/ggml/src/ggml-opencl/kernels/im2col_f16.cl
index b84c8984653c2..cf6cdaa4ce58c 100644
--- a/ggml/src/ggml-opencl/kernels/im2col_f16.cl
+++ b/ggml/src/ggml-opencl/kernels/im2col_f16.cl
@@ -31,7 +31,7 @@ kernel void kernel_im2col_f16(
src1 = (global float*)((global char*)src1 + offset1);
dst = (global half*)((global char*)dst + offsetd);
- long ksize = OW * (KH > 1 ? KW : 1);
+ long ksize = OW * KH;
long kx = i / ksize;
long kd = kx * ksize;
long ky = (i - kd) / OW;
diff --git a/ggml/src/ggml-opencl/kernels/im2col_f32.cl b/ggml/src/ggml-opencl/kernels/im2col_f32.cl
index 4bf65e4eaafba..1ecdb2344ad9d 100644
--- a/ggml/src/ggml-opencl/kernels/im2col_f32.cl
+++ b/ggml/src/ggml-opencl/kernels/im2col_f32.cl
@@ -31,7 +31,7 @@ kernel void kernel_im2col_f32(
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
- long ksize = OW * (KH > 1 ? KW : 1);
+ long ksize = OW * KH;
long kx = i / ksize;
long kd = kx * ksize;
long ky = (i - kd) / OW;
diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp
index 52737cc746dfa..7adcb3d9d9c76 100644
--- a/ggml/src/ggml-sycl/im2col.cpp
+++ b/ggml/src/ggml-sycl/im2col.cpp
@@ -26,7 +26,7 @@ static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_
// make each work-item deal with more elements since sycl global range can not exceed max int
for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) {
- const int64_t ksize = OW * (KH > 1 ? KW : 1);
+ const int64_t ksize = OW * KH;
const int64_t kx = i / ksize;
const int64_t kd = kx * ksize;
const int64_t ky = (i - kd) / OW;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
index 17c7ccb90d001..fdbcf7eba0fa5 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
@@ -40,12 +40,10 @@ void main() {
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
const int oh_s1 = int(oh) * p.s1;
- const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
+ const uint ksize = p.OW * p.KH;
const uint base_linear_idx = gidx * NUM_ITER;
- const uint max_ky = ksize / p.OW;
-
uint current_kx = base_linear_idx / ksize;
const uint rem = base_linear_idx - (current_kx * ksize);
uint current_ky = rem / p.OW;
@@ -76,7 +74,7 @@ void main() {
if (++current_ix == p.OW) {
current_ix = 0;
- if (++current_ky == max_ky) {
+ if (++current_ky == p.KH) {
current_ky = 0;
current_kx++;
}
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 731b4980af947..a6d00542dd21e 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -5093,6 +5093,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
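+    // non-square kernel (KW=3, KH=4): exercises the corrected im2col ksize computation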
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
// Conv_2D test cases
#ifdef DETAILED_TESTS
diff --git a/tools/main/main.cpp b/tools/main/main.cpp
index 516bf09652484..eb36c6884059c 100644
--- a/tools/main/main.cpp
+++ b/tools/main/main.cpp
@@ -785,14 +785,17 @@ int main(int argc, char ** argv) {
}
// check for reverse prompt using special tokens
- llama_token last_token = common_sampler_last(smpl);
- for (auto token : antiprompt_token) {
- if (token == last_token) {
- if (params.interactive) {
- is_interacting = true;
+ // avoid calling common_sampler_last() if last_output is empty
+ if (!last_output.empty()) {
+ llama_token last_token = common_sampler_last(smpl);
+ for (auto token : antiprompt_token) {
+ if (token == last_token) {
+ if (params.interactive) {
+ is_interacting = true;
+ }
+ is_antiprompt = true;
+ break;
}
- is_antiprompt = true;
- break;
}
}
diff --git a/tools/server/README.md b/tools/server/README.md
index e29511cb1b457..aa07f1ef5b177 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -575,6 +575,8 @@ These words will not be included in the completion, so make sure to add them to
`add_special`: (Optional) Boolean indicating if special tokens, i.e. `BOS`, should be inserted. Default: `false`
+`parse_special`: (Optional) Boolean indicating if special tokens should be tokenized. When `false`, special tokens are treated as plain text. Default: `true`
+
`with_pieces`: (Optional) Boolean indicating whether to return token pieces along with IDs. Default: `false`
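+
+For example, a request that disables special-token parsing (illustrative; assumes a server listening on the default port):
+
+```shell
+curl http://localhost:8080/tokenize \
+    -H "Content-Type: application/json" \
+    -d '{"content": "<|im_start|>user", "parse_special": false}'
+```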
**Response:**
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 0afe213af1e47..022b5d0b31034 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -253,6 +253,7 @@ struct server_task {
defaults.sampling = params_base.sampling;
defaults.speculative = params_base.speculative;
defaults.n_keep = params_base.n_keep;
+ defaults.antiprompt = params_base.antiprompt;
// enabling this will output extra debug information in the HTTP responses from the server
params.verbose = params_base.verbosity > 9;
@@ -490,6 +491,10 @@ struct server_task {
}
}
}
+ // set reverse prompt from cli args if not set in the request
+ if (params.antiprompt.empty()) {
+ params.antiprompt = defaults.antiprompt;
+ }
}
{
@@ -4516,9 +4521,10 @@ int main(int argc, char ** argv) {
json tokens_response = json::array();
if (body.count("content") != 0) {
const bool add_special = json_value(body, "add_special", false);
+ const bool parse_special = json_value(body, "parse_special", true);
const bool with_pieces = json_value(body, "with_pieces", false);
- llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, true);
+ llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special);
if (with_pieces) {
for (const auto& token : tokens) {