diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ec2ab5a58d027..8d6ba5f9f366f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1515,3 +1515,29 @@ jobs: run: | vulkaninfo --summary GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp + + ggml-ci-arm64-cpu-kleidiai: + runs-on: ubuntu-22.04-arm + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.16 + with: + key: ggml-ci-arm64-cpu-kleidiai + evict-old-files: 1d + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install -y build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + run: | + GG_BUILD_KLEIDIAI=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + diff --git a/ci/run.sh b/ci/run.sh index f925956a84e84..bf0d53f20af56 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -22,6 +22,9 @@ # # with MUSA support # GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt # +# # with KLEIDIAI support +# GG_BUILD_KLEIDIAI=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt +# if [ -z "$2" ]; then echo "usage: $0 " @@ -115,6 +118,34 @@ if [ ! -z ${GG_BUILD_NO_SVE} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm" fi +if [ -n "${GG_BUILD_KLEIDIAI}" ]; then + echo ">>===== Enabling KleidiAI support" + + CANDIDATES=("armv9-a+dotprod+i8mm" "armv8.6-a+dotprod+i8mm" "armv8.2-a+dotprod") + CPU="" + + for cpu in "${CANDIDATES[@]}"; do + if echo 'int main(){}' | ${CXX:-c++} -march="$cpu" -x c++ - -c -o /dev/null >/dev/null 2>&1; then + CPU="$cpu" + break + fi + done + + if [ -z "$CPU" ]; then + echo "ERROR: None of the required ARM baselines (armv9/armv8.6/armv8.2 + dotprod) are supported by this compiler." 
+ exit 1 + fi + + echo ">>===== Using ARM baseline: ${CPU}" + + CMAKE_EXTRA="${CMAKE_EXTRA:+$CMAKE_EXTRA } \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_KLEIDIAI=ON \ + -DGGML_CPU_AARCH64=ON \ + -DGGML_CPU_ARM_ARCH=${CPU} \ + -DBUILD_SHARED_LIBS=OFF" +fi + ## helpers # download a file if it does not exist or if it is outdated diff --git a/common/arg.cpp b/common/arg.cpp index 4204f6c6908fb..d17645cf2f395 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1935,6 +1935,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.n_ctx_checkpoints = value; } ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--cache-ram", "-cram"}, "N", + string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)\n" + "[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib), + [](common_params & params, int value) { + params.cache_ram_mib = value; + } + ).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--kv-unified", "-kvu"}, string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n" diff --git a/common/chat.h b/common/chat.h index a1afe574bd0ca..f7b36ec711df4 100644 --- a/common/chat.h +++ b/common/chat.h @@ -33,8 +33,8 @@ struct common_chat_msg_content_part { struct common_chat_msg { std::string role; std::string content; - std::vector content_parts = {}; - std::vector tool_calls = {}; + std::vector content_parts; + std::vector tool_calls; std::string reasoning_content; std::string tool_name; std::string tool_call_id; @@ -44,7 +44,7 @@ struct common_chat_msg { bool empty() const { return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); } - void ensure_tool_call_ids_set(std::vector & ids_cache, const std::function & gen_tool_call_id) { + void set_tool_call_ids(std::vector & ids_cache, const std::function & gen_tool_call_id) { for (auto i = 0u; i < tool_calls.size(); i++) { if (ids_cache.size() <= i) { auto id = tool_calls[i].id; diff --git a/common/common.h b/common/common.h index 0d3638c9c6228..040a44ebd89b0 100644 --- a/common/common.h +++ b/common/common.h @@ -378,7 +378,7 @@ struct common_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool no_perf = false; // disable performance metrics - bool ctx_shift = false; // context shift on infinite text generation + bool ctx_shift = false; // context shift on infinite text generation bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) bool kv_unified = false; // enable unified KV cache @@ -425,7 +425,8 @@ struct common_params { int32_t timeout_write = timeout_read; // http write timeout in seconds int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting - int32_t n_ctx_checkpoints = 3; // max number of context checkpoints per slot + int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot + int32_t cache_ram_mib = 8192; // 0 = no limit, 1 = 1 MiB, etc. 
std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a59ebfc0da776..43d345bcb480c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -93,13 +93,15 @@ class ModelBase: # Mistral format specifics is_mistral_format: bool = False disable_mistral_community_chat_template: bool = False + sentence_transformers_dense_modules: bool = False def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None, - disable_mistral_community_chat_template: bool = False): + disable_mistral_community_chat_template: bool = False, + sentence_transformers_dense_modules: bool = False): if type(self) is ModelBase or \ type(self) is TextModel or \ type(self) is MmprojModel: @@ -114,6 +116,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.lazy = not eager or (remote_hf_model_id is not None) self.dry_run = dry_run self.remote_hf_model_id = remote_hf_model_id + self.sentence_transformers_dense_modules = sentence_transformers_dense_modules if remote_hf_model_id is not None: self.is_safetensors = True @@ -5269,6 +5272,53 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter @ModelBase.register("Gemma3TextModel") class EmbeddingGemma(Gemma3Model): model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING + module_paths = [] + dense_features_dims = {} + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.sentence_transformers_dense_modules: + # read modules.json to determine if model has Dense layers + modules_file = self.dir_model / "modules.json" + if modules_file.is_file(): + with open(modules_file, encoding="utf-8") as modules_json_file: + mods = json.load(modules_json_file) + for mod in mods: + if mod["type"] == "sentence_transformers.models.Dense": + mod_path = mod["path"] + # check if model.safetensors file for Dense layer exists + model_tensors_file = self.dir_model / mod_path / "model.safetensors" + if model_tensors_file.is_file(): + self.module_paths.append(mod_path) + # read config.json of the Dense layer to get in/out features + mod_conf_file = self.dir_model / mod_path / "config.json" + if mod_conf_file.is_file(): + with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file: + mod_conf = json.load(mod_conf_json_file) + # hparams dense_2_feat_out and dense_3_feat_in are required when loading model's dense weights + prefix = self._get_dense_prefix(mod_path) + if mod_conf["in_features"] is not None and mod_conf["out_features"] is not None: + self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"]) + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + from safetensors.torch import load_file + module_paths = list(self.module_paths) + for i, module_path in enumerate(module_paths): + tensors_file = self.dir_model / module_path / "model.safetensors" + local_tensors = load_file(tensors_file) + tensor_name = self._get_dense_prefix(module_path) + for name, local_tensor in local_tensors.items(): + if not name.endswith(".weight"): + continue + orig_name = name.replace("linear", tensor_name) + name = self.map_tensor_name(orig_name) + 
yield name, local_tensor.clone() + + @staticmethod + def _get_dense_prefix(module_path) -> str: + """Get the tensor name prefix for the Dense layer from module path.""" + tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3" + return tensor_name def set_gguf_parameters(self): super().set_gguf_parameters() @@ -5285,6 +5335,10 @@ def set_gguf_parameters(self): logger.info(f"Using original sliding_window from config: {orig_sliding_window} " f"instead of {self.hparams['sliding_window']}") self.gguf_writer.add_sliding_window(orig_sliding_window) + if self.sentence_transformers_dense_modules: + for dense, dims in self.dense_features_dims.items(): + logger.info(f"Setting dense layer {dense} in/out features to {dims}") + self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1]) self._try_set_pooling_type() @@ -9335,6 +9389,13 @@ def parse_args() -> argparse.Namespace: ) ) + parser.add_argument( + "--sentence-transformers-dense-modules", action="store_true", + help=("Whether to include sentence-transformers dense modules." + "It can be used for sentence-transformers models, like google/embeddinggemma-300m" + "Default these modules are not included.") + ) + args = parser.parse_args() if not args.print_supported_models and args.model is None: parser.error("the following arguments are required: model") @@ -9397,9 +9458,13 @@ def main() -> None: if args.remote: hf_repo_id = args.model from huggingface_hub import snapshot_download + allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"] + if args.sentence_transformers_dense_modules: + # include sentence-transformers dense modules safetensors files + allowed_patterns.append("*.safetensors") local_dir = snapshot_download( repo_id=hf_repo_id, - allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]) + allow_patterns=allowed_patterns) dir_model = Path(local_dir) logger.info(f"Downloaded config and tokenizer to {local_dir}") else: @@ -9467,7 +9532,8 @@ def main() -> None: split_max_tensors=args.split_max_tensors, split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, small_first_shard=args.no_tensor_first_split, - remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template + remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template, + sentence_transformers_dense_modules=args.sentence_transformers_dense_modules ) if args.vocab_only: diff --git a/examples/model-conversion/Makefile b/examples/model-conversion/Makefile index f0867cfe46c3a..25b0514b29bc5 100644 --- a/examples/model-conversion/Makefile +++ b/examples/model-conversion/Makefile @@ -116,20 +116,39 @@ embedding-convert-model: METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ ./scripts/embedding/convert-model.sh +embedding-convert-model-st: + $(call validate_embedding_model_path,embedding-convert-model-st) + @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \ + METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ + ./scripts/embedding/convert-model.sh -st + embedding-run-original-model: $(call validate_embedding_model_path,embedding-run-original-model) @EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \ + USE_SENTENCE_TRANSFORMERS="$(USE_SENTENCE_TRANSFORMERS)" \ ./scripts/embedding/run-original-model.py \ - $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") + $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \ + $(if $(USE_SENTENCE_TRANSFORMERS),--use-sentence-transformers) + 
+embedding-run-original-model-st: USE_SENTENCE_TRANSFORMERS=1 +embedding-run-original-model-st: embedding-run-original-model embedding-run-converted-model: @./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \ - $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") + $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \ + $(if $(USE_POOLING),--pooling) + +embedding-run-converted-model-st: USE_POOLING=1 +embedding-run-converted-model-st: embedding-run-converted-model embedding-verify-logits: embedding-run-original-model embedding-run-converted-model @./scripts/embedding/compare-embeddings-logits.sh \ $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") +embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st + @./scripts/embedding/compare-embeddings-logits.sh \ + $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") + embedding-inspect-original-model: $(call validate_embedding_model_path,embedding-inspect-original-model) @EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH} diff --git a/examples/model-conversion/README.md b/examples/model-conversion/README.md index e95e05cd377cc..05d95d588bae7 100644 --- a/examples/model-conversion/README.md +++ b/examples/model-conversion/README.md @@ -189,6 +189,23 @@ This command will save two files to the `data` directory, one is a binary file containing logits which will be used for comparison with the converted model, and the other is a text file which allows for manual visual inspection. +#### Using SentenceTransformer with numbered layers +For models that have numbered SentenceTransformer layers (01_Pooling, 02_Dense, +03_Dense, 04_Normalize), use the `-st` targets to apply all these layers: + +```console +# Run original model with SentenceTransformer (applies all numbered layers) +(venv) $ make embedding-run-original-model-st + +# Run converted model with pooling enabled +(venv) $ make embedding-run-converted-model-st +``` + +This will use the SentenceTransformer library to load and run the model, which +automatically applies all the numbered layers in the correct order. This is +particularly useful when comparing with models that should include these +additional transformation layers beyond just the base model output. + ### Model conversion After updates have been made to [gguf-py](../../gguf-py) to add support for the new model the model can be converted to GGUF format using the following command: @@ -208,6 +225,13 @@ was done manually in the previous steps) and compare the logits: (venv) $ make embedding-verify-logits ``` +For models with SentenceTransformer layers, use the `-st` verification target: +```console +(venv) $ make embedding-verify-logits-st +``` +This convenience target automatically runs both the original model with SentenceTransformer +and the converted model with pooling enabled, then compares the results. 
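As a rough sketch of what the `-st` path adds (mirroring what `scripts/embedding/run-original-model.py` does above; the model id below is only an example), the base model alone yields per-token hidden states, while SentenceTransformer applies the numbered modules and returns one pooled, transformed, normalized vector per prompt:

```python
# Minimal illustration, assuming sentence-transformers and transformers are
# installed and the example model id is accessible locally or on the Hub.
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

model_path = "google/embeddinggemma-300m"  # example model with Dense modules
texts = ["Hello world today"]

# Base model only: per-token hidden states, no pooling/Dense/normalize applied.
tokenizer = AutoTokenizer.from_pretrained(model_path)
base = AutoModel.from_pretrained(model_path)
with torch.no_grad():
    encoded = tokenizer(texts, return_tensors="pt")
    token_embeddings = base(**encoded).last_hidden_state[0]  # [seq_len, hidden_size]

# Full pipeline: the numbered modules are applied in order, yielding one
# sentence embedding per input text.
st_model = SentenceTransformer(model_path)
sentence_embedding = st_model.encode(texts, convert_to_numpy=True)[0]  # [output_dim]

print(token_embeddings.shape, sentence_embedding.shape)
```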
+ ### llama-server verification To verify that the converted model works with llama-server, the following command can be used: diff --git a/examples/model-conversion/logits.cpp b/examples/model-conversion/logits.cpp index 6dc334189f4be..bbd095e6034cc 100644 --- a/examples/model-conversion/logits.cpp +++ b/examples/model-conversion/logits.cpp @@ -1,4 +1,7 @@ #include "llama.h" +#include "common.h" + + #include #include #include @@ -8,7 +11,10 @@ static void print_usage(int, char ** argv) { printf("\nexample usage:\n"); - printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [prompt]\n", argv[0]); + printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm ] [prompt]\n", argv[0]); + printf("\n"); + printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n"); + printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n"); printf("\n"); } @@ -17,6 +23,8 @@ int main(int argc, char ** argv) { std::string prompt = "Hello, my name is"; int ngl = 0; bool embedding_mode = false; + bool pooling_enabled = false; + int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) { int i = 1; @@ -41,9 +49,13 @@ int main(int argc, char ** argv) { return 1; } } else if (strcmp(argv[i], "-embd-mode") == 0) { + embedding_mode = true; + } else if (strcmp(argv[i], "-pooling") == 0) { + pooling_enabled = true; + } else if (strcmp(argv[i], "-embd-norm") == 0) { if (i + 1 < argc) { try { - embedding_mode = true; + embd_norm = std::stoi(argv[++i]); } catch (...) { print_usage(argc, argv); return 1; @@ -112,7 +124,7 @@ int main(int argc, char ** argv) { ctx_params.no_perf = false; if (embedding_mode) { ctx_params.embeddings = true; - ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE; + ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE; ctx_params.n_ubatch = ctx_params.n_batch; } @@ -143,17 +155,27 @@ int main(int argc, char ** argv) { return 1; } - float * logits; - int n_logits; + float * data_ptr; + int data_size; const char * type; + std::vector embd_out; if (embedding_mode) { - logits = llama_get_embeddings(ctx); - n_logits = llama_model_n_embd(model) * batch.n_tokens; + const int n_embd = llama_model_n_embd(model); + const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens; + const int n_embeddings = n_embd * n_embd_count; + float * embeddings; type = "-embeddings"; - const int n_embd = llama_model_n_embd(model); - const int n_embd_count = batch.n_tokens; + if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) { + embeddings = llama_get_embeddings_seq(ctx, 0); + embd_out.resize(n_embeddings); + printf("Normalizing embeddings using norm: %d\n", embd_norm); + common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm); + embeddings = embd_out.data(); + } else { + embeddings = llama_get_embeddings(ctx); + } printf("Embedding dimension: %d\n", n_embd); printf("\n"); @@ -164,7 +186,7 @@ int main(int argc, char ** argv) { // Print first 3 values for (int i = 0; i < 3 && i < n_embd; i++) { - printf("%9.6f ", logits[j * n_embd + i]); + printf("%9.6f ", embeddings[j * n_embd + i]); } printf(" ... 
"); @@ -172,7 +194,7 @@ int main(int argc, char ** argv) { // Print last 3 values for (int i = n_embd - 3; i < n_embd; i++) { if (i >= 0) { - printf("%9.6f ", logits[j * n_embd + i]); + printf("%9.6f ", embeddings[j * n_embd + i]); } } @@ -180,27 +202,33 @@ int main(int argc, char ** argv) { } printf("\n"); - printf("Embeddings size: %d\n", n_logits); + printf("Embeddings size: %d\n", n_embeddings); + + data_ptr = embeddings; + data_size = n_embeddings; } else { - logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - n_logits = llama_vocab_n_tokens(vocab); + float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + const int n_logits = llama_vocab_n_tokens(vocab); type = ""; printf("Vocab size: %d\n", n_logits); + + data_ptr = logits; + data_size = n_logits; } std::filesystem::create_directory("data"); - // Save logits to binary file + // Save data to binary file char bin_filename[512]; snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type); - printf("Saving logits to %s\n", bin_filename); + printf("Saving data to %s\n", bin_filename); FILE * f = fopen(bin_filename, "wb"); if (f == NULL) { fprintf(stderr, "%s: error: failed to open binary output file\n", __func__); return 1; } - fwrite(logits, sizeof(float), n_logits, f); + fwrite(data_ptr, sizeof(float), data_size, f); fclose(f); // Also save as text for debugging @@ -211,27 +239,27 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: error: failed to open text output file\n", __func__); return 1; } - for (int i = 0; i < n_logits; i++) { - fprintf(f, "%d: %.6f\n", i, logits[i]); + for (int i = 0; i < data_size; i++) { + fprintf(f, "%d: %.6f\n", i, data_ptr[i]); } fclose(f); if (!embedding_mode) { printf("First 10 logits: "); - for (int i = 0; i < 10 && i < n_logits; i++) { - printf("%.6f ", logits[i]); + for (int i = 0; i < 10 && i < data_size; i++) { + printf("%.6f ", data_ptr[i]); } printf("\n"); printf("Last 10 logits: "); - for (int i = n_logits - 10; i < n_logits; i++) { - if (i >= 0) printf("%.6f ", logits[i]); + for (int i = data_size - 10; i < data_size; i++) { + if (i >= 0) printf("%.6f ", data_ptr[i]); } printf("\n\n"); } - printf("Logits saved to %s\n", bin_filename); - printf("Logits saved to %s\n", txt_filename); + printf("Data saved to %s\n", bin_filename); + printf("Data saved to %s\n", txt_filename); llama_free(ctx); llama_model_free(model); diff --git a/examples/model-conversion/requirements.txt b/examples/model-conversion/requirements.txt index ac9f69e10bcc9..229b2ec75b75b 100644 --- a/examples/model-conversion/requirements.txt +++ b/examples/model-conversion/requirements.txt @@ -4,3 +4,4 @@ torchvision transformers huggingface-hub accelerate +sentence-transformers diff --git a/examples/model-conversion/scripts/embedding/convert-model.sh b/examples/model-conversion/scripts/embedding/convert-model.sh index 0929e42413e67..9926350c072b2 100755 --- a/examples/model-conversion/scripts/embedding/convert-model.sh +++ b/examples/model-conversion/scripts/embedding/convert-model.sh @@ -2,6 +2,21 @@ set -e +# Parse command line arguments +SENTENCE_TRANSFORMERS="" +while [[ $# -gt 0 ]]; do + case $1 in + -st|--sentence-transformers) + SENTENCE_TRANSFORMERS="--sentence-transformers-dense-modules" + shift + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + MODEL_NAME="${MODEL_NAME:-$(basename "$EMBEDDING_MODEL_PATH")}" OUTPUT_DIR="${OUTPUT_DIR:-../../models}" TYPE="${OUTTYPE:-f16}" @@ -15,7 +30,8 @@ echo "Converted model path:: ${CONVERTED_MODEL}" python 
../../convert_hf_to_gguf.py --verbose \ ${EMBEDDING_MODEL_PATH} \ --outfile ${CONVERTED_MODEL} \ - --outtype ${TYPE} + --outtype ${TYPE} \ + ${SENTENCE_TRANSFORMERS} echo "" echo "The environment variable CONVERTED_EMBEDDING MODEL can be set to this path using:" diff --git a/examples/model-conversion/scripts/embedding/run-converted-model.sh b/examples/model-conversion/scripts/embedding/run-converted-model.sh index f3e2676632070..0f490e6c3b20a 100755 --- a/examples/model-conversion/scripts/embedding/run-converted-model.sh +++ b/examples/model-conversion/scripts/embedding/run-converted-model.sh @@ -5,6 +5,7 @@ set -e # Parse command line arguments CONVERTED_MODEL="" PROMPTS_FILE="" +USE_POOLING="" while [[ $# -gt 0 ]]; do case $1 in @@ -12,6 +13,10 @@ while [[ $# -gt 0 ]]; do PROMPTS_FILE="$2" shift 2 ;; + --pooling) + USE_POOLING="1" + shift + ;; *) if [ -z "$CONVERTED_MODEL" ]; then CONVERTED_MODEL="$1" @@ -47,4 +52,8 @@ echo $CONVERTED_MODEL cmake --build ../../build --target llama-logits -j8 # TODO: update logits.cpp to accept a --file/-f option for the prompt -../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT" +if [ -n "$USE_POOLING" ]; then + ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT" +else + ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT" +fi diff --git a/examples/model-conversion/scripts/embedding/run-original-model.py b/examples/model-conversion/scripts/embedding/run-original-model.py index 4a3e162413fa6..640e200a97dc3 100755 --- a/examples/model-conversion/scripts/embedding/run-original-model.py +++ b/examples/model-conversion/scripts/embedding/run-original-model.py @@ -14,6 +14,8 @@ parser = argparse.ArgumentParser(description='Process model with specified path') parser.add_argument('--model-path', '-m', help='Path to the model') parser.add_argument('--prompts-file', '-p', help='Path to file containing prompts (one per line)') +parser.add_argument('--use-sentence-transformers', action='store_true', + help='Use SentenceTransformer to apply all numbered layers (01_Pooling, 02_Dense, 03_Dense, 04_Normalize)') args = parser.parse_args() def read_prompt_from_file(file_path): @@ -31,41 +33,52 @@ def read_prompt_from_file(file_path): if model_path is None: parser.error("Model path must be specified either via --model-path argument or EMBEDDING_MODEL_PATH environment variable") -tokenizer = AutoTokenizer.from_pretrained(model_path) +# Determine if we should use SentenceTransformer +use_sentence_transformers = args.use_sentence_transformers or os.environ.get('USE_SENTENCE_TRANSFORMERS', '').lower() in ('1', 'true', 'yes') -config = AutoConfig.from_pretrained(model_path) - -# This can be used to override the sliding window size for manual testing. This -# can be useful to verify the sliding window attention mask in the original model -# and compare it with the converted .gguf model. 
-if hasattr(config, 'sliding_window'): - original_sliding_window = config.sliding_window - #original_sliding_window = 6 - print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}") - -print(f"Using unreleased model: {unreleased_model_name}") -if unreleased_model_name: - model_name_lower = unreleased_model_name.lower() - unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}" - class_name = f"{unreleased_model_name}Model" - print(f"Importing unreleased model module: {unreleased_module_path}") - - try: - model_class = getattr(importlib.import_module(unreleased_module_path), class_name) - model = model_class.from_pretrained(model_path, config=config) - except (ImportError, AttributeError) as e: - print(f"Failed to import or load model: {e}") - exit(1) +if use_sentence_transformers: + from sentence_transformers import SentenceTransformer + print("Using SentenceTransformer to apply all numbered layers") + model = SentenceTransformer(model_path) + tokenizer = model.tokenizer + config = model[0].auto_model.config # type: ignore else: - model = AutoModel.from_pretrained(model_path, config=config) -print(f"Model class: {type(model)}") -print(f"Model file: {type(model).__module__}") + tokenizer = AutoTokenizer.from_pretrained(model_path) + + config = AutoConfig.from_pretrained(model_path) + + # This can be used to override the sliding window size for manual testing. This + # can be useful to verify the sliding window attention mask in the original model + # and compare it with the converted .gguf model. + if hasattr(config, 'sliding_window'): + original_sliding_window = config.sliding_window + #original_sliding_window = 6 + print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}") + + print(f"Using unreleased model: {unreleased_model_name}") + if unreleased_model_name: + model_name_lower = unreleased_model_name.lower() + unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}" + class_name = f"{unreleased_model_name}Model" + print(f"Importing unreleased model module: {unreleased_module_path}") + + try: + model_class = getattr(importlib.import_module(unreleased_module_path), class_name) + model = model_class.from_pretrained(model_path, config=config) + except (ImportError, AttributeError) as e: + print(f"Failed to import or load model: {e}") + exit(1) + else: + model = AutoModel.from_pretrained(model_path, config=config) + print(f"Model class: {type(model)}") + print(f"Model file: {type(model).__module__}") # Verify the model is using the correct sliding window -if hasattr(model.config, 'sliding_window'): - print(f"Model's sliding_window: {model.config.sliding_window}") -else: - print("Model config does not have sliding_window attribute") +if not use_sentence_transformers: + if hasattr(model.config, 'sliding_window'): # type: ignore + print(f"Model's sliding_window: {model.config.sliding_window}") # type: ignore + else: + print("Model config does not have sliding_window attribute") model_name = os.path.basename(model_path) @@ -75,34 +88,56 @@ def read_prompt_from_file(file_path): else: texts = ["Hello world today"] -encoded = tokenizer( - texts, - padding=True, - truncation=True, - return_tensors="pt" -) - -tokens = encoded['input_ids'][0] -token_strings = tokenizer.convert_ids_to_tokens(tokens) -for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): - print(f"{token_id:6d} -> '{token_str}'") - with torch.no_grad(): - outputs = model(**encoded) - 
hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size] - - # Extract embeddings for each token (matching LLAMA_POOLING_TYPE_NONE behavior) - all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size] - - print(f"Hidden states shape: {hidden_states.shape}") - print(f"All embeddings shape: {all_embeddings.shape}") - print(f"Embedding dimension: {all_embeddings.shape[1]}") - - # Print embeddings exactly like embedding.cpp does for LLAMA_POOLING_TYPE_NONE - n_embd = all_embeddings.shape[1] - n_embd_count = all_embeddings.shape[0] - - print() # Empty line to match C++ output + if use_sentence_transformers: + embeddings = model.encode(texts, convert_to_numpy=True) + all_embeddings = embeddings # Shape: [batch_size, hidden_size] + + encoded = tokenizer( + texts, + padding=True, + truncation=True, + return_tensors="pt" + ) + tokens = encoded['input_ids'][0] + token_strings = tokenizer.convert_ids_to_tokens(tokens) + for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): + print(f"{token_id:6d} -> '{token_str}'") + + print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}") + print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}") # type: ignore + else: + # Standard approach: use base model output only + encoded = tokenizer( + texts, + padding=True, + truncation=True, + return_tensors="pt" + ) + + tokens = encoded['input_ids'][0] + token_strings = tokenizer.convert_ids_to_tokens(tokens) + for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): + print(f"{token_id:6d} -> '{token_str}'") + + outputs = model(**encoded) + hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size] + + all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size] + + print(f"Hidden states shape: {hidden_states.shape}") + print(f"All embeddings shape: {all_embeddings.shape}") + print(f"Embedding dimension: {all_embeddings.shape[1]}") + + if len(all_embeddings.shape) == 1: + n_embd = all_embeddings.shape[0] # type: ignore + n_embd_count = 1 + all_embeddings = all_embeddings.reshape(1, -1) + else: + n_embd = all_embeddings.shape[1] # type: ignore + n_embd_count = all_embeddings.shape[0] # type: ignore + + print() for j in range(n_embd_count): embedding = all_embeddings[j] @@ -120,29 +155,23 @@ def read_prompt_from_file(file_path): print() # New line - print() # Final empty line to match C++ output + print() data_dir = Path("data") data_dir.mkdir(exist_ok=True) bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin" txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt" - # Save all embeddings flattened (matching what embedding.cpp would save if it did) flattened_embeddings = all_embeddings.flatten() flattened_embeddings.astype(np.float32).tofile(bin_filename) with open(txt_filename, "w") as f: - f.write(f"# Model class: {model_name}\n") - f.write(f"# Tokens: {token_strings}\n") - f.write(f"# Shape: {all_embeddings.shape}\n") - f.write(f"# n_embd_count: {n_embd_count}, n_embd: {n_embd}\n\n") - + idx = 0 for j in range(n_embd_count): - f.write(f"# Token {j} ({token_strings[j]}):\n") - for i, value in enumerate(all_embeddings[j]): - f.write(f"{j}_{i}: {value:.6f}\n") - f.write("\n") - print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} tokens × {n_embd} dimensions)") + for value in all_embeddings[j]: + f.write(f"{idx}: {value:.6f}\n") + idx += 1 + print(f"Total values: 
{len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)") print("") print(f"Saved bin embeddings to: {bin_filename}") print(f"Saved txt embeddings to: {txt_filename}") diff --git a/examples/model-conversion/scripts/utils/semantic_check.py b/examples/model-conversion/scripts/utils/semantic_check.py index 7fd417bceaa8b..2ac8b6b7b42cb 100644 --- a/examples/model-conversion/scripts/utils/semantic_check.py +++ b/examples/model-conversion/scripts/utils/semantic_check.py @@ -35,7 +35,11 @@ def cosine_similarity(a, b=None): def load_embeddings_from_file(filename, n_tokens, n_embd): embeddings = np.fromfile(filename, dtype=np.float32) - return embeddings.reshape(n_tokens, n_embd) + # Check if this is pooled (single embedding) or per-token embeddings + if len(embeddings) == n_embd: + return embeddings.reshape(1, n_embd) + else: + return embeddings.reshape(n_tokens, n_embd) def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt): np.set_printoptions(suppress=True, precision=6) @@ -48,58 +52,83 @@ def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt): print(f"Embeddings shape: Python {python_emb.shape}, llama.cpp {cpp_emb.shape}") n_tokens = len(tokens) + is_pooled = python_emb.shape[0] == 1 + + if is_pooled: + print(f"\n[Pooled Embeddings Mode - comparing single sentence embeddings]") - # 1. Direct embedding comparison - print(f"\n1. Raw Embedding Magnitude Comparison:") - # Check if the distance of each token embedding from the origin and compare - # if the vectors are on the same "sphere". This does not tell us about - # direction (meaning of the token embedding), just magnitude. - for i in range(n_tokens): - py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings - cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings + # 1. Direct embedding comparison for pooled embeddings + print(f"\n1. Raw Embedding Magnitude Comparison:") + py_mag = np.linalg.norm(python_emb[0]) + cpp_mag = np.linalg.norm(cpp_emb[0]) ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf') - print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}") - - # 2. Cosine similarity between tokens within each model - # Here we check the direction of token embeddings to see if the have the - # same meaning (similarity). This is done by calculating cosine similarity - # of a pair of token embeddings within each model. - print(f"\n2. Within-Model Token Similarities:") - print(" Python model:") - for i in range(n_tokens): - for j in range(i+1, n_tokens): - sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0] - print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}") - - print(" llama.cpp model:") - for i in range(n_tokens): - for j in range(i+1, n_tokens): - sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0] - print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}") - - # 3. Cross-model similarity (same token position) - print(f"\n3. Cross-Model Same-Token Similarities:") - for i in range(n_tokens): - sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] - print(f" Token {i} ({tokens[i]}): {sim:.4f}") - - # 4. Similarity matrix comparison - print(f"\n4. 
Similarity Matrix Differences:") - py_sim_matrix = cosine_similarity(python_emb) - cpp_sim_matrix = cosine_similarity(cpp_emb) - diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix) - - print(f" Max difference: {np.max(diff_matrix):.4f}") - print(f" Mean difference: {np.mean(diff_matrix):.4f}") - print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}") - - return { - 'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)], - 'similarity_matrix_diff': diff_matrix, - 'max_diff': np.max(diff_matrix), - 'mean_diff': np.mean(diff_matrix), - 'rms_diff': np.sqrt(np.mean(diff_matrix**2)) - } + print(f" Pooled embedding: Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}") + + # 2. Cross-model similarity for pooled embeddings + print(f"\n2. Cross-Model Pooled Embedding Similarity:") + sim = cosine_similarity([python_emb[0]], [cpp_emb[0]])[0][0] + print(f" Cosine similarity: {sim:.6f}") + + return { + 'cross_model_similarities': [sim], + 'similarity_matrix_diff': np.array([[0.0]]), + 'max_diff': 0.0, + 'mean_diff': 0.0, + 'rms_diff': 0.0 + } + else: + # Original per-token comparison logic + # 1. Direct embedding comparison + print(f"\n1. Raw Embedding Magnitude Comparison:") + # Check if the distance of each token embedding from the origin and compare + # if the vectors are on the same "sphere". This does not tell us about + # direction (meaning of the token embedding), just magnitude. + for i in range(n_tokens): + py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings + cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings + ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf') + print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}") + + # 2. Cosine similarity between tokens within each model + # Here we check the direction of token embeddings to see if the have the + # same meaning (similarity). This is done by calculating cosine similarity + # of a pair of token embeddings within each model. + print(f"\n2. Within-Model Token Similarities:") + print(" Python model:") + for i in range(n_tokens): + for j in range(i+1, n_tokens): + sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0] + print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}") + + print(" llama.cpp model:") + for i in range(n_tokens): + for j in range(i+1, n_tokens): + sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0] + print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}") + + # 3. Cross-model similarity (same token position) + print(f"\n3. Cross-Model Same-Token Similarities:") + for i in range(n_tokens): + sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] + print(f" Token {i} ({tokens[i]}): {sim:.4f}") + + # 4. Similarity matrix comparison + print(f"\n4. 
Similarity Matrix Differences:") + py_sim_matrix = cosine_similarity(python_emb) + cpp_sim_matrix = cosine_similarity(cpp_emb) + diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix) + + print(f" Max difference: {np.max(diff_matrix):.4f}") + print(f" Mean difference: {np.mean(diff_matrix):.4f}") + print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}") + + return { + 'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)], + 'similarity_matrix_diff': diff_matrix, + 'max_diff': np.max(diff_matrix), + 'mean_diff': np.mean(diff_matrix), + 'rms_diff': np.sqrt(np.mean(diff_matrix**2)) + } def read_prompt_from_file(file_path): try: diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index b707b843593c7..debbcadc1e4c5 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -341,11 +341,18 @@ class cann_task_queue { #ifdef USE_ACL_GRAPH struct ggml_graph_node_properties { + // dst tensor void * node_address; - ggml_op node_op; int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; + + // src tensor void * src_address[GGML_MAX_SRC]; + int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; + size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; + + // op + ggml_op node_op; int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; }; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index b51b554e752e1..ad1adba6b3a8a 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties( std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb); for (int src = 0; src < GGML_MAX_SRC; ++src) { - prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr; + if (node->src[src]) { + prop.src_address[src] = node->src[src]->data; + std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]); + std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]); + } else { + prop.src_address[src] = nullptr; + std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0); + std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0); + } } memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS); @@ -2206,14 +2214,18 @@ static void add_lru_matched_graph_node_properties( * @param graph_node_properties The stored properties of a CANN graph node. * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise. 
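 * Source tensor shapes (ne) and strides (nb) are compared as well, so a node whose sources changed geometry is reported as non-matching.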
*/ -static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { +static bool ggml_graph_node_has_matching_properties( + ggml_tensor * node, + ggml_graph_node_properties * graph_node_properties) { if (node->data != graph_node_properties->node_address && - node->op != GGML_OP_VIEW) { + node->op != GGML_OP_VIEW) { return false; } + if (node->op != graph_node_properties->node_op) { return false; } + for (int i = 0; i < GGML_MAX_DIMS; i++) { if (node->ne[i] != graph_node_properties->ne[i]) { return false; @@ -2222,17 +2234,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return false; } } + for (int i = 0; i < GGML_MAX_SRC; i++) { - if (node->src[i] && - node->src[i]->data != graph_node_properties->src_address[i] && - node->op != GGML_OP_VIEW - ) { - return false; + if (node->src[i]) { + if (node->src[i]->data != graph_node_properties->src_address[i] && + node->op != GGML_OP_VIEW) { + return false; + } + + for (int d = 0; d < GGML_MAX_DIMS; d++) { + if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) { + return false; + } + if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) { + return false; + } + } + } else { + if (graph_node_properties->src_address[i] != nullptr) { + return false; + } } } - if (node->op == GGML_OP_SCALE && - memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { - return false; + + if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) { + return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; } return true; } diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index 7ba659124ca27..3eaa5e3f4100f 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -29,6 +29,108 @@ #define NELEMS(x) sizeof(x) / sizeof(*x) +template +static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) { + return Fn(a, b, c); +} + +template +static inline size_t kernel_offs_fn2(size_t a, size_t b, size_t) { + return Fn(a, b); +} + +template +static inline void kernel_run_fn11(size_t m, size_t n, size_t k, size_t bl, + const void* lhs, const void* rhs, void* dst, + size_t dst_stride_row, size_t dst_stride_col, + float clamp_min, float clamp_max) { + Fn(m, n, k, bl, lhs, rhs, static_cast(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max); +} + +template +static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/, + const void* lhs, const void* rhs, void* dst, + size_t dst_stride_row, size_t dst_stride_col, + float clamp_min, float clamp_max) { + Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max); +} + +template +static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { + return Fn(m, k, bl, mr, kr, sr); +} + +template +static inline size_t lhs_ps_fn5(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) { + return Fn(m, k, mr, kr, sr); +} + +template +static inline size_t lhs_offs_fn6(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { + return Fn(m_idx, k, bl, mr, kr, sr); +} + +template +static inline size_t lhs_offs_fn5(size_t m_idx, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) { + return Fn(m_idx, k, mr, kr, sr); +} + +template +static inline void lhs_pack_float_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, + size_t 
m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { + Fn(m, k, bl, mr, kr, sr, m_idx_start, static_cast(lhs), lhs_stride, lhs_packed); +} + +template +static inline void lhs_pack_void_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, + size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { + Fn(m, k, bl, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed); +} + +template +static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr, + size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { + Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed); +} + +template +static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) { + return Fn(n, k, nr, kr, bl); +} + +template +static inline size_t rhs_ps_fn2(size_t n, size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) { + return Fn(n, k); +} + +template +static inline size_t rhs_stride_fn4(size_t k, size_t nr, size_t kr, size_t bl) { + return Fn(k, nr, kr, bl); +} + +template +static inline size_t rhs_stride_fn1(size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) { + return Fn(k); +} + +template +static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, + size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* /*scale*/, + void* rhs_packed, size_t extra_bytes, const void* params) { + Fn(num_groups, n, k, nr, kr, sr, bl, + static_cast(rhs), + static_cast(bias), + rhs_packed, extra_bytes, + static_cast(params)); +} + +template +static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/, + size_t rhs_stride, const void* rhs, const void* bias, const void* scale, + void* rhs_packed, size_t extra_bytes, const void* params) { + Fn(num_groups, n, k, nr, kr, sr, rhs_stride, rhs, bias, scale, rhs_packed, extra_bytes, params); +} + static const size_t INT4_PER_BYTE = 2; static const size_t INT4_BITS = 4; static const int Q4_0_ZERO_POINT = 8; @@ -122,17 +224,18 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, + /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, 
+ /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* SME GEMV */ /* .kern_info = */ { @@ -142,23 +245,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, - /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, + /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_SME, /* .lhs_type = */ GGML_TYPE_F32, @@ -174,17 +278,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_lhs_offset_ex = */ &kernel_offs_fn2, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2, + /* .run_kernel_ex = */ &kernel_run_fn10, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, - /* .pack_func = */ 
kai_run_lhs_pack_bf16p2vlx2_f32_sme, + /* .get_packed_offset_ex = */ &lhs_offs_fn5, + /* .packed_size_ex = */ &lhs_ps_fn5, + /* .pack_func_ex = */ &lhs_pack_void_fn9, }, /* SME GEMV */ /* .kern_info = */ { @@ -194,23 +298,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_lhs_offset_ex = */ nullptr, + /* .get_rhs_packed_offset_ex = */ nullptr, + /* .run_kernel_ex = */ nullptr, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, - /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, + /* .get_packed_offset_ex = */ &lhs_offs_fn5, + /* .packed_size_ex = */ &lhs_ps_fn5, + /* .pack_func_ex = */ &lhs_pack_void_fn9, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - /* .packed_stride = */ NULL, - /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - /* .to_float = */ NULL, + /* .packed_stride = */ nullptr, + /* .to_float = */ nullptr, + /* .packed_size_ex = */ &rhs_ps_fn2, + /* .packed_stride_ex = */ &rhs_stride_fn1, + /* .pack_func_ex = */ &rhs_pack_fn13, }, /* .required_cpu = */ CPU_FEATURE_SME, /* .lhs_type = */ GGML_TYPE_F32, @@ -229,17 +334,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, 
}, /* DOTPROD GEMV */ /* .kern_info = */ { @@ -249,23 +354,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, /* .lhs_type = */ GGML_TYPE_F32, @@ -283,17 +389,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* 
.packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* i8mm GEMV */ /* .kern_info = */ { @@ -303,23 +409,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, @@ -338,17 +445,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, 
- /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* i8mm GEMV */ /* .kern_info = */ { @@ -358,23 +465,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, @@ -392,17 +500,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ 
kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* DOTPROD GEMV */ /* .kern_info = */ { @@ -412,23 +520,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, /* .lhs_type = */ GGML_TYPE_F32, @@ -443,6 +552,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c ggml_kleidiai_kernels * kernel = nullptr; if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) { +#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu && gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type && @@ -452,6 +562,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c break; } } +#endif } return kernel; @@ -460,12 +571,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) { ggml_kleidiai_kernels * kernels = nullptr; +#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) for (size_t i = 0; i < 
NELEMS(gemm_gemv_kernels); ++i) { if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) { kernels = &gemm_gemv_kernels[i]; break; } } +#endif return kernels; } diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h index 2ad6ad6fd0bfc..a84795a6b2e50 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.h +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -4,8 +4,6 @@ #pragma once -#include -#include #include "ggml.h" enum cpu_feature { @@ -15,6 +13,7 @@ enum cpu_feature { CPU_FEATURE_SVE = 4, CPU_FEATURE_SME = 8 }; + inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) { lhs = static_cast(lhs | rhs); return lhs; @@ -30,63 +29,52 @@ struct kernel_info { size_t (*get_nr)(void); size_t (*get_kr)(void); size_t (*get_sr)(void); - std::variant< - std::function, - std::function - > get_lhs_offset; - std::variant< - std::function, - std::function - > get_rhs_packed_offset; + size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride); size_t (*get_dst_size)(size_t m, size_t n); - std::variant< - std::function, - std::function - > run_kernel; + + size_t (*get_lhs_offset_ex)(size_t m_idx, size_t k, size_t bl); + + size_t (*get_rhs_packed_offset_ex)(size_t n_idx, size_t k, size_t bl); + + void (*run_kernel_ex)( + size_t m, size_t n, size_t k, size_t bl, + const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_stride_row, size_t dst_stride_col, + float clamp_min, float clamp_max); }; struct lhs_packing_info { size_t (*get_offset)(size_t m_idx, size_t lhs_stride); - std::variant< - std::function, - std::function - > get_packed_offset; - std::variant< - std::function, - std::function - > packed_size; - std::variant< - std::function, - std::function - > pack_func; + + size_t (*get_packed_offset_ex)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + + size_t (*packed_size_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + + void (*pack_func_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, + size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed); }; struct rhs_packing_info { - std::variant< - std::function, - std::function - > packed_size; size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl); - std::variant< - std::function, - std::function - > pack_func; - void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride, - size_t kr, size_t bl, size_t num_bytes_multiplier); + + void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, + size_t nr_pack, size_t packed_row_stride, size_t kr, size_t bl, + size_t num_bytes_multiplier); + + size_t (*packed_size_ex)(size_t n, size_t k, size_t nr, size_t kr, size_t bl); + + size_t (*packed_stride_ex)(size_t k, size_t nr, size_t kr, size_t bl); + + void (*pack_func_ex)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, + size_t rhs_stride, const void * rhs, const void * bias, const void * scale, void * rhs_packed, size_t extra_bytes, const void * params); }; struct ggml_kleidiai_kernels { - kernel_info gemm; + kernel_info gemm; lhs_packing_info gemm_lhs_info; - kernel_info gemv; + kernel_info gemv; lhs_packing_info gemv_lhs_info; rhs_packing_info rhs_info; diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 44691e5dfdf6a..8b3df7d78009e 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ 
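Note: the kernels.h rewrite above drops the std::variant/std::function members in favor of plain C function pointers with an `_ex` suffix, so each dispatch-table entry can reference a small free function that adapts one KleidiAI entry point to a single widened signature. The adapter names used in the tables (kernel_offs_fn3, lhs_offs_fn6, rhs_pack_fn12, and friends) are defined elsewhere in the patch; the sketch below only illustrates the adapter idea with hypothetical stand-in functions, not the real KleidiAI API.

    #include <cstddef>
    #include <cstdio>

    // Hypothetical narrow-signature entry points standing in for kai_get_*_offset_* helpers.
    static size_t native_offset_2arg(size_t n_idx, size_t k)            { return n_idx * k;        }
    static size_t native_offset_3arg(size_t n_idx, size_t k, size_t bl) { return n_idx * (k / bl); }

    // Widened signature used by the table: callers always pass (n_idx, k, bl);
    // each adapter drops or forwards arguments as its underlying kernel expects.
    typedef size_t (*get_rhs_packed_offset_ex_t)(size_t n_idx, size_t k, size_t bl);

    static size_t adapt_offset_2arg(size_t n_idx, size_t k, size_t /*bl*/) { return native_offset_2arg(n_idx, k);     }
    static size_t adapt_offset_3arg(size_t n_idx, size_t k, size_t bl)     { return native_offset_3arg(n_idx, k, bl); }

    struct kernel_entry {
        get_rhs_packed_offset_ex_t get_rhs_packed_offset_ex;
    };

    static const kernel_entry table[] = {
        { &adapt_offset_2arg }, // e.g. f16 path: block length unused
        { &adapt_offset_3arg }, // e.g. q4_0 path: block length forwarded
    };

    int main() {
        // Uniform call site regardless of which underlying kernel was selected.
        std::printf("%zu %zu\n",
                    table[0].get_rhs_packed_offset_ex(2, 256, 32),
                    table[1].get_rhs_packed_offset_ex(2, 256, 32));
    }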
b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #if defined(__linux__) #include #include @@ -87,40 +88,6 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { return tensor->ne[dim]; } -template -constexpr bool variant_any_invocable_impl(std::index_sequence) { - using V = std::remove_reference_t; - return (std::is_invocable_r_v< - Ret, - std::variant_alternative_t, - Args...> || ...); -} - -template -constexpr bool variant_any_invocable_v = - variant_any_invocable_impl( - std::make_index_sequence< - std::variant_size_v>>{}); - -template -static inline Ret variant_call(Variant && var, Args&&... args) { - static_assert(variant_any_invocable_v, Ret, Args...>, - "No alternative in Variant is invocable with the provided arguments and return type."); - - return std::visit( - [&](auto && f) -> Ret { - using F = std::decay_t; - if constexpr (std::is_invocable_r_v) { - return std::invoke(std::forward(f), std::forward(args)...); - } else { - GGML_ABORT("Invalid function type in variant_call"); - GGML_UNREACHABLE(); - } - }, - std::forward(var) - ); -} - namespace ggml::cpu::kleidiai { static size_t round_down(size_t x, size_t y) { @@ -145,7 +112,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { return false; } ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op); - GGML_ASSERT(kernels); + if (!kernels) { + return false; + } bool is_gemv = op->src[1]->ne[1] == 1; kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; @@ -159,16 +128,18 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t sr = kernel->get_sr(); if (kernels->rhs_type == GGML_TYPE_Q4_0) { - size = variant_call(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr); + if (!lhs_info->packed_size_ex) return false; + size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr); } else if (kernels->rhs_type == GGML_TYPE_F16) { + if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false; const int64_t lhs_batch_size0 = op->src[1]->ne[2]; const int64_t rhs_batch_size0 = op->src[0]->ne[2]; const int64_t r = lhs_batch_size0 / rhs_batch_size0; - size = variant_call(lhs_info->packed_size, m * r, k, mr, kr, sr) + - variant_call(kernels->rhs_info.packed_size, n, k) + + size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) + + kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) + k * n * sizeof(float) + n * sizeof(float); } else { - GGML_ASSERT(false); + return false; } return true; @@ -196,12 +167,18 @@ class tensor_traits : public ggml::cpu::tensor_traits { GGML_TENSOR_BINARY_OP_LOCALS ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); - GGML_ASSERT(kernels); + if (!kernels) { + return false; + } const bool is_gemv = src1->ne[1] == 1; kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; lhs_packing_info * lhs_info = is_gemv ? 
&kernels->gemv_lhs_info : &kernels->gemm_lhs_info; GGML_ASSERT(kernel); + if (!kernels->rhs_info.pack_func_ex || + !kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) { + return false; + } const int nth = params->nth; const int ith = params->ith; @@ -228,10 +205,10 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t kr = (int64_t) kernel->get_kr(); const int64_t sr = (int64_t) kernel->get_sr(); - const size_t lhs_packed_size = variant_call(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); - const size_t rhs_packed_size = variant_call(kernels->rhs_info.packed_size, (size_t)n, (size_t)k); - const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float); - const size_t bias_size = (size_t)n * sizeof(float); + const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr); + const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0); + const size_t kxn_size = k * n * sizeof(float); + const size_t bias_size = n * sizeof(float); const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size; GGML_ASSERT(wsize_required <= params->wsize); @@ -259,10 +236,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; // Base packed offset (aligned) and per-row stride in bytes - const size_t base_packed_off = variant_call( - lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); - const size_t next_block_off = variant_call( - lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); + const size_t base_packed_off = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr); + const size_t next_block_off = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr); const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr; int64_t remaining = m_count; @@ -278,9 +253,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes; void * dst_ptr = lhs_packed + dst_off; - variant_call(lhs_info->pack_func, - (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr, - /*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr); + lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr); cur += take; remaining -= take; @@ -296,10 +269,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { reinterpret_cast(rhs_batch_base), rhs_stride); - variant_call(kernels->rhs_info.pack_func, - /*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr, - /*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)), - rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr); + kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float), + rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr); } ggml_barrier(params->threadpool); @@ -320,20 +291,15 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t n_to_process = (ith == num_threads_n - 1) ? 
num_n_per_threadN_1 : num_n_per_thread0; // LHS packed base at row 0 (consistent with packing above) - const size_t lhs_packed_offset0 = variant_call( - lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); - const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k); - const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride); + const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr); + const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0); + const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride); const void * lhs_ptr = lhs_packed + lhs_packed_offset0; const void * rhs_ptr = rhs_packed + rhs_packed_offset; float * dst_ptr = reinterpret_cast(dst_batch_base + dst_offset); - variant_call(kernel->run_kernel, - (size_t)m, (size_t)n_to_process, (size_t)k, - lhs_ptr, rhs_ptr, - dst_ptr, dst_stride, sizeof(float), - -FLT_MAX, FLT_MAX); + kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); } } @@ -354,13 +320,19 @@ class tensor_traits : public ggml::cpu::tensor_traits { GGML_TENSOR_BINARY_OP_LOCALS ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); - GGML_ASSERT(kernels); + if (!kernels) { + return false; + } bool is_gemv = src1->ne[1] == 1; kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; GGML_ASSERT(kernel); + if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex || + !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) { + return false; + } const int ith = params->ith; const int nth_raw = params->nth; @@ -402,25 +374,26 @@ class tensor_traits : public ggml::cpu::tensor_traits { // Transform LHS const size_t src_stride = src1->nb[1]; const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); - const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr); + const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr); void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); - variant_call(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); + // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer + lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); } ggml_barrier(params->threadpool); // Perform the operation const size_t dst_stride = dst->nb[1]; - const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr); - const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, n_start, k, QK4_0); + const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr); + const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0); const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); if (n_to_process > 0) { - variant_call(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, 
rhs_ptr, dst_ptr, dst_stride, + kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); } @@ -429,7 +402,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0); - GGML_ASSERT(ctx.kernels); + if (!ctx.kernels) { + return false; + } const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -438,6 +413,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { rhs_packing_info * rhs_info = &ctx.kernels->rhs_info; kernel_info * kernel = &ctx.kernels->gemm; + if (!rhs_info->to_float || !kernel->get_nr) { + return false; + } const int64_t nc = ne00; const int64_t nr = ggml_nelements(src1); @@ -480,7 +458,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { struct kai_rhs_pack_qs4cxs1s0_param params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; - variant_call(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms); + ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, ¶ms); return 0; GGML_UNUSED(data_size); @@ -548,7 +526,7 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_ const size_t nr = ctx.kernels->gemm.get_nr(); const size_t kr = ctx.kernels->gemm.get_kr(); - return variant_call(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0); + return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0); GGML_UNUSED(buft); } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 8e1a2de14f983..1c43865ff65fc 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3467,31 +3467,27 @@ static void ggml_compute_forward_norm_f32( GGML_ASSERT(eps >= 0.0f); - // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ith; i01 < ne01; i01 += nth) { const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - ggml_float sum = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)x[i00]; - } - + float sum = 0.0; + ggml_vec_sum_f32(ne00, &sum, x); float mean = sum/ne00; float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + float variance = 0; - ggml_float sum2 = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - float v = x[i00] - mean; - y[i00] = v; - sum2 += (ggml_float)(v*v); - } +#ifdef GGML_USE_ACCELERATE + mean = -mean; + vDSP_vsadd(x, 1, &mean, y, 1, ne00); + vDSP_measqv(y, 1, &variance, ne00); +#else + variance = ggml_vec_cvar_f32(ne00, y, x, mean); +#endif //GGML_USE_ACCELERATE - float variance = sum2/ne00; const float scale = 1.0f/sqrtf(variance + eps); - ggml_vec_scale_f32(ne00, y, scale); } } diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 437192d525a34..b8e37052d35e1 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -404,6 +404,72 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * } } +ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) { + int i = 0; + ggml_float sum = 0; +// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE +// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344 +#if defined(__AVX512F__) && 
defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i), + _mm512_set1_ps(mean)); + _mm512_storeu_ps(y + i, val); + sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val)); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i), + _mm256_set1_ps(mean)); + _mm256_storeu_ps(y + i, val); + val = _mm256_mul_ps(val,val); + __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1), + _mm256_castps256_ps128(val)); + val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2)); + val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2)); + sum += (ggml_float)_mm_cvtss_f32(val2); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i), + _mm_set1_ps(mean)); + _mm_storeu_ps(y + i, val); + val = _mm_mul_ps(val, val); +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) + val = _mm_add_ps(val, _mm_movehl_ps(val, val)); + val = _mm_add_ss(val, _mm_movehdup_ps(val)); +#else + __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1)); + val = _mm_add_ps(val, tmp); + tmp = _mm_movehl_ps(tmp, val); + val = _mm_add_ss(val, tmp); +#endif // __AVX__ || __AVX2__ || __AVX512F__ + sum += (ggml_float)_mm_cvtss_f32(val); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + float32x4_t val = vsubq_f32(vld1q_f32(x + i), + vdupq_n_f32(mean)); + vst1q_f32(y + i, val); + val = vmulq_f32(val, val); + sum += (ggml_float)vaddvq_f32(val); + } +#elif defined(__VXE__) || defined(__VXE2__) + for (; i + 3 < n; i += 4) { + float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean)); + vec_xst(val, 0, y + i); + val = vec_mul(val, val); + sum += (ggml_float)vec_hsum_f32x4(val); + } +#endif + for (; i < n; ++i) { + float val = x[i] - mean; + val *= val; + sum += (ggml_float)val; + y[i] = val; + } + return sum/n; +} + ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { int i = 0; ggml_float sum = 0; diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index f95ca94e54b16..2751359ce49f4 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -44,6 +44,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc); void ggml_vec_silu_f32(const int n, float * y, const float * x); +ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); //it will also center y ( y = y - mean ) ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max); ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max); diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 4e7449d06ecfe..d66d7ade90182 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -197,6 +197,7 @@ struct sycl_device_info { int cc; // compute capability // int nsm; // number of streaming multiprocessors // size_t smpb; // max. shared memory per block + size_t smpbo; // max. 
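Note: ggml_vec_cvar_f32 fuses the centering step of the f32 norm with the variance reduction: it writes y = x - mean and returns the mean of the squared deviations, which ggml_compute_forward_norm_f32 then turns into a 1/sqrt(variance + eps) scale. A minimal scalar reference of the same computation (plain loops, none of the SIMD paths above) might look like this:

    #include <cmath>

    // Scalar reference for the fused center-and-variance step: y[i] = x[i] - mean,
    // return value = mean of squared deviations.
    static float cvar_ref(int n, float * y, const float * x, float mean) {
        double sum2 = 0.0;
        for (int i = 0; i < n; ++i) {
            const float v = x[i] - mean;
            y[i]  = v;
            sum2 += (double) v * v;
        }
        return (float) (sum2 / n);
    }

    // Usage mirroring the norm above: compute the mean, center and get the variance,
    // then scale the centered row by 1/sqrt(variance + eps).
    static void norm_row_ref(int n, float * y, const float * x, float eps) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            sum += x[i];
        }
        const float mean     = (float) (sum / n);
        const float variance = cvar_ref(n, y, x, mean);
        const float scale    = 1.0f / std::sqrt(variance + eps);
        for (int i = 0; i < n; ++i) {
            y[i] *= scale;
        }
    }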
shared memory per block (with opt-in) bool vmm; // virtual memory support size_t total_vram; //sycl_hw_info hw_info; \\ device id and aarch, currently not used @@ -416,13 +417,6 @@ static __dpct_inline__ float warp_reduce_sum(float x, const sycl::nd_item<3>& item_ct1) { #pragma unroll for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { - /* - DPCT1096:98: The right-most dimension of the work-group used in the SYCL - kernel that calls this function may be less than "32". The function - "dpct::permute_sub_group_by_xor" may return an unexpected result on the - CPU device. Modify the size of the work-group to ensure that the value - of the right-most dimension is a multiple of "32". - */ x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask); } return x; @@ -440,17 +434,67 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) { return a; } +template +static __dpct_inline__ int warp_reduce_sum(int x) { + return sycl::reduce_over_group( + sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>()); +} + +template +static __dpct_inline__ float warp_reduce_sum(float x) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + x += dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), x, offset, width); + } + return x; +} + +template +static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + a.x() += dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), a.x(), offset, + width); + a.y() += dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), a.y(), offset, + width); + } + return a; +} + +template +static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + a = a + dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), a, offset, + width); + } + return a; +} + +static constexpr int ggml_sycl_get_physical_warp_size() { + // todo: for old iGPU + dGPU case, need to be changed. + return WARP_SIZE; +} + +template +static __dpct_inline__ float warp_reduce_max(float x) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + x = sycl::fmax(x, dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), x, + offset, width)); + } + return x; +} + static __dpct_inline__ float warp_reduce_max(float x, const sycl::nd_item<3>& item_ct1) { #pragma unroll for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { - /* - DPCT1096:97: The right-most dimension of the work-group used in the SYCL - kernel that calls this function may be less than "32". The function - "dpct::permute_sub_group_by_xor" may return an unexpected result on the - CPU device. Modify the size of the work-group to ensure that the value - of the right-most dimension is a multiple of "32". - */ x = sycl::fmax(x, dpct::permute_sub_group_by_xor( item_ct1.get_sub_group(), x, mask)); } @@ -558,4 +602,18 @@ struct scope_op_debug_print { std::string_view func_suffix; }; +static __dpct_inline__ float get_alibi_slope(const float max_bias, + const uint32_t h, + const uint32_t n_head_log2, + const float m0, + const float m1) { + if (max_bias <= 0.0f) { + return 1.0f; + } + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? 
h + 1 : 2*(h - n_head_log2) + 1; + + return dpct::pow(base, exph); +} + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index d538965b096bf..f93cfa701f584 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -277,6 +277,26 @@ namespace dpct } // namespace detail + // COPY from DPCT head files + /// dim3 is used to store 3 component dimensions. + class dim3 { + public: + unsigned x, y, z; + + constexpr dim3(unsigned x = 1, unsigned y = 1, unsigned z = 1) + : x(x), y(y), z(z) {} + + dim3(const sycl::id<3> &r) : dim3(r[2], r[1], r[0]) {} + + operator sycl::range<3>() const { return sycl::range<3>(z, y, x); } + }; // namespace dim3 + + inline dim3 operator*(const dim3 &a, const dim3 &b) { + return dim3{a.x * b.x, a.y * b.y, a.z * b.z}; + } + // COPY from DPCT head files + + /// Pitched 2D/3D memory data. class pitched_data { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 4ac919ea2d757..e4cc3c8ed8f2a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -87,6 +87,7 @@ static ggml_sycl_device_info ggml_sycl_init() { 100 * prop.get_major_version() + 10 * prop.get_minor_version(); info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu); info.max_work_group_sizes[i] = prop.get_max_work_group_size(); + info.devices[i].smpbo = prop.get_local_mem_size(); } for (int id = 0; id < info.device_count; ++id) { @@ -3741,6 +3742,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_SOFT_MAX: ggml_sycl_op_soft_max(ctx, dst); break; + case GGML_OP_SOFT_MAX_BACK: + ggml_sycl_op_soft_max_back(ctx, dst); + break; case GGML_OP_ROPE: ggml_sycl_rope(ctx, dst); break; @@ -3778,6 +3782,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg return true; } catch (sycl::exception & e) { std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::cerr << "Error OP "<op)<< std::endl; std::exit(1); } @@ -4386,19 +4391,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return true; case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; - case GGML_OP_SOFT_MAX: - // TODO: support batching - if (op->src[0]->ne[3] != 1) { - return false; - } - // TODO: support attention sinks [TAG_ATTN_SINKS] - if (op->src[2]) { - return false; - } - // TODO: support broadcast - // ref: https://github.com/ggml-org/llama.cpp/pull/14435 - return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); case GGML_OP_DIAG_MASK_INF: + return true; + case GGML_OP_SOFT_MAX: + return true; + case GGML_OP_SOFT_MAX_BACK: { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); + return max_bias == 0.0f; + } case GGML_OP_ROPE: case GGML_OP_IM2COL: return true; diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index 52fcf4b3dbd24..83b7c71b66194 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -1,37 +1,94 @@ #include "softmax.hpp" +#include +#include +#include -template -static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, - const int nrows_y, const float scale, const float max_bias, const float m0, - const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) { - const int ncols = 
ncols_template == 0 ? ncols_par : ncols_template; - const int tid = item_ct1.get_local_id(2); - const int rowx = item_ct1.get_group(2); - const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension +template static __dpct_inline__ float t2f32(T val) { + return (float) val; +} - const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template; +template <> float __dpct_inline__ t2f32(sycl::half val) { + return sycl::vec(val) + .convert()[0]; +} - const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; - const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; +struct soft_max_params { + + int64_t nheads; + uint32_t n_head_log2; + int64_t ncols; + int64_t nrows_x; + int64_t nrows_y; + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + int64_t nb11; + int64_t nb12; + int64_t nb13; + + int64_t ne12; + int64_t ne13; + float scale; + float max_bias; + float m0; + float m1; +}; + +// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. +// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template +static void soft_max_f32(const float * x, + const T * mask, + const float * sinks, + float * dst, + const soft_max_params p, + uint8_t * dpct_local) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int ncols = ncols_template == 0 ? p.ncols : ncols_template; + const int block_size = block_size_template == 0 + ? item_ct1.get_local_range(2) + : block_size_template; const int nthreads = block_size; const int nwarps = nthreads / WARP_SIZE; size_t nreduce = nwarps / WARP_SIZE; - float slope = 1.0f; - // ALiBi - if (max_bias > 0.0f) { - const uint32_t h = rowx/nrows_y; // head index + const int tid = item_ct1.get_local_id(2); - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const int64_t i03 = item_ct1.get_group(0); + const int64_t i02 = item_ct1.get_group(1); + const int64_t i01 = item_ct1.get_group(2); - slope = sycl::pow(base, float(exp)); - } + //TODO: noncontigous inputs/outputs + const int rowx = item_ct1.get_group(2) + + item_ct1.get_group(1) * item_ct1.get_group_range(2) + + item_ct1.get_group(0) * item_ct1.get_group_range(2) * + item_ct1.get_group_range(1); + + const int64_t i11 = i01; + const int64_t i12 = i02 % p.ne12; + const int64_t i13 = i03 % p.ne13; - float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols; - float max_val = -INFINITY; + x += int64_t(rowx)*ncols; + mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr); + dst += int64_t(rowx)*ncols; + const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; + const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + + const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); + + float * buf_iw = (float *) dpct_local; + + // shared memory buffer to cache values between iterations: + float *vals = use_shared ? buf_iw + sycl::max(nwarps, WARP_SIZE) : dst; + float max_val = sinks ? 
sinks[i02] : -INFINITY; +#pragma unroll for (int col0 = 0; col0 < ncols; col0 += block_size) { const int col = col0 + tid; @@ -39,42 +96,35 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int break; } - const int ix = rowx*ncols + col; - const int iy = rowy*ncols + col; - - const float val = x[ix]*scale + (mask ? slope*static_cast(mask[iy]) : 0.0f); + const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f); vals[col] = val; - max_val = sycl::max(max_val, val); + max_val = sycl::max(max_val, val); } - // find the max value in the block - max_val = warp_reduce_max(max_val, item_ct1); + max_val = warp_reduce_max(max_val); + if (block_size > WARP_SIZE) { if (warp_id == 0) { - buf[lane_id] = -INFINITY; - for (size_t i = 1; i < nreduce; i += 1) { - buf[lane_id + i * WARP_SIZE] = -INFINITY; - } + buf_iw[lane_id] = -INFINITY; } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(); if (lane_id == 0) { - buf[warp_id] = max_val; + buf_iw[warp_id] = max_val; } - item_ct1.barrier(sycl::access::fence_space::local_space); - max_val = buf[lane_id]; - for (size_t i = 1; i < nreduce; i += 1) { - max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]); - } - max_val = warp_reduce_max(max_val, item_ct1); + item_ct1.barrier(); + + max_val = buf_iw[lane_id]; + max_val = warp_reduce_max(max_val); } + float tmp = 0.0f; // partial sum - float tmp = 0.f; #pragma unroll for (int col0 = 0; col0 < ncols; col0 += block_size) { const int col = col0 + tid; - if (ncols_template == 0 && col >= ncols) { + + if (ncols_template == 0 && col >= ncols) { break; } @@ -82,32 +132,33 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int tmp += val; vals[col] = val; } - // find the sum of exps in the block - tmp = warp_reduce_sum(tmp, item_ct1); + tmp = warp_reduce_sum(tmp); if (block_size > WARP_SIZE) { - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(); if (warp_id == 0) { - buf[lane_id] = 0.f; + buf_iw[lane_id] = 0.0f; for (size_t i = 1; i < nreduce; i += 1) { - buf[lane_id + i * WARP_SIZE] = 0.f; + buf_iw[lane_id + i * WARP_SIZE] = 0.f; } } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(); if (lane_id == 0) { - buf[warp_id] = tmp; + buf_iw[warp_id] = tmp; } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(); - tmp = buf[lane_id]; + tmp = buf_iw[lane_id]; for (size_t i = 1; i < nreduce; i += 1) { - tmp += buf[lane_id + i * WARP_SIZE]; + tmp += buf_iw[lane_id + i * WARP_SIZE]; } - tmp = warp_reduce_sum(tmp, item_ct1); + tmp = warp_reduce_sum(tmp); } - - const float inv_sum = 1.f / tmp; + if (sinks) { + tmp += sycl::native::exp(sinks[i02] - max_val); + } + const float inv_sum = 1.0f / tmp; #pragma unroll for (int col0 = 0; col0 < ncols; col0 += block_size) { @@ -117,145 +168,259 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int return; } - const int idst = rowx*ncols + col; - dst[idst] = vals[col] * inv_sum; + dst[col] = vals[col] * inv_sum; } } +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + +static void soft_max_back_f32(const float *grad, const float *dstf, float *dst, + const int ncols, const float scale) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int tid = item_ct1.get_local_id(2); + const int rowx = item_ct1.get_group(2); + + grad += int64_t(rowx)*ncols; + dstf += int64_t(rowx)*ncols; + dst += int64_t(rowx)*ncols; + + float dgf_dot = 
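Note: in the rewritten soft_max_f32 above, the optional per-head sink value acts as one extra virtual logit: the running maximum is seeded with sinks[i02], and exp(sinks[i02] - max_val) is added to the normalizer, so the real columns are scaled as if the sink were part of the row while never producing an output column. A scalar sketch of that normalization (single row, mask and ALiBi bias omitted) is:

    #include <algorithm>
    #include <cmath>

    // Scalar sketch: softmax over one row with an optional "sink" logit that only
    // contributes to the denominator and has no output column of its own.
    static void softmax_row_with_sink(int ncols, float * dst, const float * x,
                                      float scale, const float * sink /* may be null */) {
        float max_val = sink ? *sink : -INFINITY;
        for (int i = 0; i < ncols; ++i) {
            dst[i]  = x[i] * scale;              // mask / ALiBi bias omitted for brevity
            max_val = std::max(max_val, dst[i]);
        }
        float denom = 0.0f;
        for (int i = 0; i < ncols; ++i) {
            dst[i]  = std::exp(dst[i] - max_val);
            denom  += dst[i];
        }
        if (sink) {
            denom += std::exp(*sink - max_val); // the sink's share of the probability mass
        }
        for (int i = 0; i < ncols; ++i) {
            dst[i] /= denom;
        }
    }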
0.0f; // dot product of dst from forward pass and gradients + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dgf_dot += dstf[col]*grad[col]; + } + + dgf_dot = warp_reduce_sum(dgf_dot); + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dst[col] = scale * (grad[col] - dgf_dot) * dstf[col]; + } +} + +template +static void launch_soft_max_kernels(const float * x, + const T * mask, + const float * sinks, + float * dst, + const soft_max_params & p, + dpct::queue_ptr stream, + dpct::dim3 block_dims, + dpct::dim3 block_nums, + size_t nbytes_shared) +{ + auto launch_kernel = [=](auto I) -> bool { + constexpr int ncols = decltype(I)::value; + constexpr int block = (ncols > 1024 ? 1024 : ncols); + if (p.ncols == ncols) { + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor dpct_local_acc_ct1( + sycl::range<1>(nbytes_shared), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + WARP_SIZE)]] { + soft_max_f32( + x, mask, sinks, dst, p, + dpct_local_acc_ct1 + .get_multi_ptr() + .get()); + GGML_UNUSED(item_ct1); + }); + }); + return true; + } + return false; + }; + + // unary fold over launch_kernel + if ((launch_kernel(std::integral_constant{}) || ...)) { + return; + } -template -static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par, - const int nrows_y, const float scale, const float max_bias, const float m0, - const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, - const size_t n_local_scratch, queue_ptr stream) { stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor local_buf_acc(n_local_scratch, cgh); + sycl::local_accessor dpct_local_acc_ct1( + sycl::range<1>(nbytes_shared), cgh); cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - soft_max_f32(x, mask, dst, ncols_par, - nrows_y, scale, max_bias, m0, - m1, n_head_log2, item_ct1, - get_pointer(local_buf_acc)); - }); + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + soft_max_f32( + x, mask, sinks, dst, p, + dpct_local_acc_ct1 + .get_multi_ptr() + .get()); + GGML_UNUSED(item_ct1); + }); }); } -template -static void soft_max_f32_sycl(const float * x, const T * mask, - float * dst, const int ncols_x, const int nrows_x, - const int nrows_y, const float scale, const float max_bias, - queue_ptr stream, int device) { +template +static void soft_max_f32_sycl(const float *x, const T *mask, + const float *sinks, float *dst, + const soft_max_params ¶ms, + dpct::queue_ptr stream, int device) { int nth = WARP_SIZE; int max_block_size = ggml_sycl_info().max_work_group_sizes[device]; + const int64_t ncols_x = params.ncols; + while (nth < ncols_x && nth < max_block_size) nth *= 2; if (nth>max_block_size) nth = max_block_size; - const sycl::range<3> block_dims(1, 1, nth); - const sycl::range<3> block_nums(1, 1, nrows_x); - const size_t n_val_tmp = nth / WARP_SIZE; - const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp); + const dpct::dim3 block_dims(nth, 1, 1); + const dpct::dim3 block_nums(params.ne01, params.ne02, params.ne03); + const size_t nbytes_shared = + (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE) * sizeof(float); - const uint32_t n_head_kv = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const int id = get_current_device_id(); + const 
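Note: soft_max_back_f32 above is the standard softmax Jacobian-vector product: with y the forward softmax output and dy the incoming gradient, each row is reduced to dgf_dot = <dy, y> and the result is dx = scale * (dy - dgf_dot) * y. A scalar reference of the same per-row computation:

    // Scalar reference for the softmax backward pass:
    // dx[i] = scale * (dy[i] - dot(dy, y)) * y[i], where y is the forward softmax output.
    static void softmax_back_row_ref(int ncols, float * dx,
                                     const float * dy, const float * y, float scale) {
        float dgf_dot = 0.0f;
        for (int i = 0; i < ncols; ++i) {
            dgf_dot += dy[i] * y[i];
        }
        for (int i = 0; i < ncols; ++i) {
            dx[i] = scale * (dy[i] - dgf_dot) * y[i];
        }
    }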
size_t smpbo = ggml_sycl_info().devices[id].smpbo; - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - const size_t local_mem_size = stream->get_device().get_info(); - if (n_local_scratch*sizeof(float) < local_mem_size) { - if (ncols_x > max_block_size) { - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - return; - } - switch (ncols_x) { - case 32: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 64: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 128: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 256: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 512: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 1024: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 2048: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 4096: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - default: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - } + if (nbytes_shared <= smpbo) { + launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>( + x, mask, sinks, dst, params, stream, block_dims, block_nums, + nbytes_shared); } else { - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, WARP_SIZE, stream); + const size_t nbytes_shared_low = WARP_SIZE * sizeof(float); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor dpct_local_acc_ct1( + sycl::range<1>(nbytes_shared_low), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + soft_max_f32( + x, mask, sinks, dst, params, + dpct_local_acc_ct1 + .get_multi_ptr() + .get()); + GGML_UNUSED(item_ct1); + }); + }); } } +static void soft_max_back_f32_sycl(const float * grad, + const float * dstf, + float * dst, + const int ncols, + const int nrows, + const float scale, + dpct::queue_ptr stream) { + const dpct::dim3 block_dims(WARP_SIZE, 1, 1); + const dpct::dim3 block_nums(nrows, 1, 1); + + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + soft_max_back_f32(grad, dstf, dst, ncols, scale); + GGML_UNUSED(item_ct1); + }); +} + void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + + const ggml_tensor * src0 = dst->src[0]; + 
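Note: launch_soft_max_kernels expands its non-type template pack with a fold over ||, so the first compile-time column count that equals p.ncols submits a fully unrolled kernel, and the generic (ncols_template == 0) submission at the end only runs when nothing matched. A stripped-down, SYCL-free illustration of the same dispatch idiom (hypothetical kernel/dispatch names, C++17):

    #include <cstdio>
    #include <type_traits>

    template <int ncols_template>
    static void kernel(int ncols) {
        // Stands in for the fully unrolled kernel instantiated for ncols_template.
        std::printf("specialized kernel: ncols_template=%d (ncols=%d)\n", ncols_template, ncols);
    }

    static void kernel_generic(int ncols) {
        std::printf("generic kernel: ncols=%d\n", ncols);
    }

    template <int... Ncols>
    static void dispatch(int ncols) {
        auto try_one = [&](auto I) -> bool {
            constexpr int n = decltype(I)::value;
            if (ncols == n) {
                kernel<n>(ncols);
                return true;
            }
            return false;
        };
        // Unary fold over ||: stops at the first matching specialization.
        if ((try_one(std::integral_constant<int, Ncols>{}) || ...)) {
            return;
        }
        kernel_generic(ncols);
    }

    int main() {
        dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(256); // hits the 256 specialization
        dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(300); // falls through to the generic path
    }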
const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + const float * src0_d = (const float *) src0->data; + const void * src1_d = src1 ? (const void *) src1->data : nullptr; + const void * src2_d = src2 ? (const void *) src2->data : nullptr; + float * dst_d = (float *) dst->data; + + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional + // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - const int64_t ne00 = dst->src[0]->ne[0]; - const int64_t nrows_x = ggml_nrows(dst->src[0]); - const int64_t nrows_y = dst->src[0]->ne[1]; + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; - float scale = 1.0f; + const int64_t ne00 = src0->ne[0]; + + float scale = 1.0f; float max_bias = 0.0f; - memcpy(&scale, dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, dst->op_params + 1, sizeof(float)); + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; - ggml_sycl_set_device(ctx.device); - dpct::queue_ptr main_stream = ctx.stream(); + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? src1->ne[3] : 1; - if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) { - const sycl::half * src1_dd = static_cast(dst->src[1]->data); - soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, - main_stream, ctx.device); - } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) { - const float * src1_dd = static_cast(dst->src[1]->data); - soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + + soft_max_params params = {}; + params.nheads = src0->ne[2]; + params.n_head_log2 = n_head_log2; + params.ncols = ne00; + params.nrows_x = nrows_x; + params.nrows_y = nrows_y; + params.ne00 = src0->ne[0]; + params.ne01 = src0->ne[1]; + params.ne02 = src0->ne[2]; + params.ne03 = src0->ne[3]; + params.nb11 = nb11; + params.nb12 = nb12; + params.nb13 = nb13; + params.ne12 = ne12; + params.ne13 = ne13; + params.scale = scale; + params.max_bias = max_bias; + params.m0 = m0; + params.m1 = m1; + + if (use_f16) { + soft_max_f32_sycl(src0_d, (const sycl::half *)src1_d, + (const float *)src2_d, dst_d, params, stream, + ctx.device); } else { - /* mask unavailable */ - soft_max_f32_sycl(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + soft_max_f32_sycl(src0_d, (const float *)src1_d, (const float *)src2_d, + dst_d, params, stream, ctx.device); } } + +void ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + 
const ggml_tensor * src0 = dst->src[0]; // grad + const ggml_tensor * src1 = dst->src[1]; // forward pass output + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + GGML_ASSERT(max_bias == 0.0f); + + soft_max_back_f32_sycl(src0_d, src1_d, dst_d, ncols, nrows, scale, stream); +} diff --git a/ggml/src/ggml-sycl/softmax.hpp b/ggml/src/ggml-sycl/softmax.hpp index 2cf8582ec92e9..23f1e5a9d65e6 100644 --- a/ggml/src/ggml-sycl/softmax.hpp +++ b/ggml/src/ggml-sycl/softmax.hpp @@ -15,6 +15,10 @@ #include "common.hpp" +#define SYCL_SOFT_MAX_BLOCK_SIZE 1024 + void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst); +void ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + #endif // GGML_SYCL_SOFTMAX_HPP diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 9c99b90faace8..f5e5fba8008bd 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -128,6 +128,8 @@ class LLM: ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx" ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs" EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input" + DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in" + DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out" class Attention: HEAD_COUNT = "{arch}.attention.head_count" @@ -433,6 +435,8 @@ class MODEL_TENSOR(IntEnum): TOKEN_TYPES = auto() POS_EMBD = auto() OUTPUT = auto() + DENSE_2_OUT = auto() # embeddinggemma 2_Dense + DENSE_3_OUT = auto() # embeddinggemma 3_Dense OUTPUT_NORM = auto() ROPE_FREQS = auto() ROPE_FACTORS_LONG = auto() @@ -777,6 +781,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POS_EMBD: "position_embd", MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embeddinggemma 2_Dense + MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embeddinggemma 2_Dense MODEL_TENSOR.ROPE_FREQS: "rope_freqs", MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", @@ -1759,6 +1765,8 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA_EMBEDDING: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.DENSE_2_OUT, + MODEL_TENSOR.DENSE_3_OUT, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index dfe4bfd490519..306679e21834b 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -730,6 +730,10 @@ def add_shared_kv_layers(self, value: int) -> None: def add_sliding_window_pattern(self, value: Sequence[bool]) -> None: self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value) + def add_dense_features_dims(self, dense:str, in_f:int, out_f:int) -> None: + self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f) + self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f) + def add_logit_scale(self, value: float) -> None: 
self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 3e9a2dd8f8cc9..c05aa6cc488de 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -76,7 +76,12 @@ class TensorNameMap: "lm_head", # llama4 "model.transformer.ff_out", # llada ), - + MODEL_TENSOR.DENSE_2_OUT: ( + "dense_2_out", # embeddinggemma + ), + MODEL_TENSOR.DENSE_3_OUT: ( + "dense_3_out", # embeddinggemma + ), # Output norm MODEL_TENSOR.OUTPUT_NORM: ( "gpt_neox.final_layer_norm", # gptneox diff --git a/requirements/requirements-all.txt b/requirements/requirements-all.txt index 56b6752ac0645..6c6bea9490b4b 100644 --- a/requirements/requirements-all.txt +++ b/requirements/requirements-all.txt @@ -14,3 +14,5 @@ -r ./requirements-tool_bench.txt -r ./requirements-gguf_editor_gui.txt + +-r ../examples/model-conversion/requirements.txt diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 45f0d0e2cbbd4..869e4dccf0dc9 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -219,6 +219,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, + // sentence-transformers dense modules feature dims + { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, + { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, + { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, + { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, @@ -1071,6 +1076,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_DENSE_2_OUT, "dense_2" }, + { LLM_TENSOR_DENSE_3_OUT, "dense_3" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, @@ -2281,6 +2288,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 507fe5f3793e0..c3ae71655b17b 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -271,6 +271,12 @@ enum llm_kv { LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, + + // sentence-transformers dense layers in and out features + LLM_KV_DENSE_2_FEAT_IN, + LLM_KV_DENSE_2_FEAT_OUT, + LLM_KV_DENSE_3_FEAT_IN, + LLM_KV_DENSE_3_FEAT_OUT, }; enum llm_tensor { @@ -278,6 +284,8 @@ enum llm_tensor { LLM_TENSOR_TOKEN_EMBD_NORM, LLM_TENSOR_TOKEN_TYPES, LLM_TENSOR_POS_EMBD, + LLM_TENSOR_DENSE_2_OUT, + LLM_TENSOR_DENSE_3_OUT, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_ROPE_FREQS, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index d8a8b5e647a85..e7526e7d0a557 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2346,6 +2346,12 @@ llama_context * 
llama_init_from_model( return nullptr; } + if (params.pooling_type != model->hparams.pooling_type) { + //user-specified pooling-type is different from the model default + LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__, + model->hparams.pooling_type, params.pooling_type); + } + try { auto * ctx = new llama_context(*model, params); return ctx; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 90cd885a60a4f..a24853c63ada4 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1853,6 +1853,23 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +void llm_graph_context::build_dense_out( + ggml_tensor * dense_2, + ggml_tensor * dense_3) const { + if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) { + return; + } + ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd; + GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd"); + + cur = ggml_mul_mat(ctx0, dense_2, cur); + cur = ggml_mul_mat(ctx0, dense_3, cur); + cb(cur, "result_embd_pooled", -1); + res->t_embd_pooled = cur; + ggml_build_forward_expand(gf, cur); +} + + void llm_graph_context::build_pooling( ggml_tensor * cls, ggml_tensor * cls_b, diff --git a/src/llama-graph.h b/src/llama-graph.h index 34b984afeb043..dc84b7942893a 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -814,6 +814,14 @@ struct llm_graph_context { ggml_tensor * cls_b, ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; + + // + // dense (out) + // + + void build_dense_out( + ggml_tensor * dense_2, + ggml_tensor * dense_3) const; }; // TODO: better name diff --git a/src/llama-hparams.h b/src/llama-hparams.h index f29b23eeffe56..4e7f73ec234c3 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -169,6 +169,12 @@ struct llama_hparams { uint32_t laurel_rank = 64; uint32_t n_embd_altup = 256; + // needed for sentence-transformers dense layers + uint32_t dense_2_feat_in = 0; // in_features of the 2_Dense + uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense + uint32_t dense_3_feat_in = 0; // in_features of the 3_Dense + uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense + // xIELU std::array xielu_alpha_n; std::array xielu_alpha_p; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 816f2d5de592b..736693e174527 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -123,11 +123,8 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k; - ggml_tensor * v; - - k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); + ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 03c2f49d78267..a5fe5b749c355 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1218,12 +1218,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.set_swa_pattern(6); hparams.causal_attn = false; // embeddings do not use causal attention - hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; - 
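
build_dense_out above chains two ggml_mul_mat calls on the pooled embedding. As a plain-CPU reference for what that computes with the weight shapes created later in this patch ({n_embd, dense_2_feat_out} and {dense_3_feat_in, n_embd}), here is a minimal sketch; the function name and the flat buffer layout are illustrative only:

    #include <vector>

    // pooled : [n_embd]      pooled sentence embedding (one row of t_embd_pooled)
    // w2     : {n_embd, d2}  2_Dense weight, ne0 (n_embd) contiguous
    // w3     : {d2, n_embd}  3_Dense weight (dense_3_feat_in == d2, dense_3_feat_out == n_embd)
    static std::vector<float> dense_out_ref(const std::vector<float> & pooled,
                                            const std::vector<float> & w2,
                                            const std::vector<float> & w3,
                                            int n_embd, int d2) {
        std::vector<float> h(d2, 0.0f);
        for (int o = 0; o < d2; ++o) {
            for (int i = 0; i < n_embd; ++i) {
                h[o] += w2[o*n_embd + i] * pooled[i]; // first ggml_mul_mat: dense_2
            }
        }
        std::vector<float> out(n_embd, 0.0f);
        for (int o = 0; o < n_embd; ++o) {
            for (int i = 0; i < d2; ++i) {
                out[o] += w3[o*d2 + i] * h[i];        // second ggml_mul_mat: dense_3
            }
        }
        return out;
    }

The graph applies only the two matrix multiplications (no bias or activation), so the sketch does the same.
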
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + //applied only if model converted with --sentence-transformers-dense-modules + ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); + ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); + + GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); + GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); switch (hparams.n_layer) { case 24: type = LLM_TYPE_0_3B; break; @@ -3686,6 +3695,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + // Dense linear weights + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -19893,6 +19907,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + // if the gguf model was converted with --sentence-transformers-dense-modules + // there will be two additional dense projection layers + // dense linear projections are applied after pooling + // TODO: move reranking logic here and generalize + llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + return llm->res->get_gf(); } diff --git a/src/llama-model.h b/src/llama-model.h index 20b59d952bf90..7f48662f2807a 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -438,6 +438,12 @@ struct llama_model { std::vector layers; + //Dense linear projections for SentenceTransformers models like embeddinggemma + // For Sentence Transformers models structure see + // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models + struct ggml_tensor * dense_2_out_layers = nullptr; + struct ggml_tensor * dense_3_out_layers = nullptr; + llama_model_params params; // gguf metadata diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 8d57b4a16772a..550df72e93c33 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index de6e1a322b2c2..41ecb279feb89 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -9,7 +9,6 @@ #include "sampling.h" #include "speculative.h" #include "mtmd.h" -#include "mtmd-helper.h" // mime type for sending response #define MIMETYPE_JSON "application/json; charset=utf-8" @@ -158,7 +157,6 @@ struct slot_params { if (only_metrics) { return json { - {"n_predict", n_predict}, // Server configured n_predict {"seed", sampling.seed}, {"temperature", sampling.temp}, {"dynatemp_range", sampling.dynatemp_range}, @@ -181,7 +179,8 @@ struct slot_params { {"mirostat", 
sampling.mirostat}, {"mirostat_tau", sampling.mirostat_tau}, {"mirostat_eta", sampling.mirostat_eta}, - {"max_tokens", n_predict}, // User configured n_predict + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? {"n_keep", n_keep}, {"n_discard", n_discard}, {"ignore_eos", sampling.ignore_eos}, @@ -209,7 +208,6 @@ struct slot_params { } return json { - {"n_predict", n_predict}, // Server configured n_predict {"seed", sampling.seed}, {"temperature", sampling.temp}, {"dynatemp_range", sampling.dynatemp_range}, @@ -234,7 +232,8 @@ struct slot_params { {"mirostat_tau", sampling.mirostat_tau}, {"mirostat_eta", sampling.mirostat_eta}, {"stop", antiprompt}, - {"max_tokens", n_predict}, // User configured n_predict + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? {"n_keep", n_keep}, {"n_discard", n_discard}, {"ignore_eos", sampling.ignore_eos}, @@ -265,15 +264,15 @@ struct server_task { int id = -1; // to be filled by server_queue int index = -1; // used when there are multiple prompts (batch request) - server_task_type type; - // used by SERVER_TASK_TYPE_CANCEL int id_target = -1; + int id_slot = -1; // used by SERVER_TASK_TYPE_INFERENCE slot_params params; - server_tokens prompt_tokens; - int id_selected_slot = -1; + server_tokens tokens; + + server_task_type type; // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE struct slot_action { @@ -289,6 +288,8 @@ struct server_task { // used by SERVER_TASK_TYPE_SET_LORA std::vector set_lora; + server_task() = default; + server_task(server_task_type type) : type(type) {} static slot_params params_from_json_cmpl( @@ -305,6 +306,7 @@ struct server_task { defaults.sampling = params_base.sampling; defaults.speculative = params_base.speculative; defaults.n_keep = params_base.n_keep; + defaults.n_predict = params_base.n_predict; defaults.antiprompt = params_base.antiprompt; // enabling this will output extra debug information in the HTTP responses from the server @@ -323,32 +325,32 @@ struct server_task { params.n_discard = json_value(data, "n_discard", defaults.n_discard); //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); - params.response_fields = json_value(data, "response_fields", std::vector()); - - params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); - params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); - params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); - params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); - params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); - params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); - params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); - params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); - params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); - params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); - params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); - params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); 
- params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); - params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); - params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); - params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); - params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); - params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); - params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); - params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); - params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); - params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); - params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.response_fields = json_value(data, "response_fields", std::vector()); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); 
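
The block above is a whitespace-only re-alignment, but it also shows the pattern every sampling field follows: read the request key if present, otherwise fall back to the server-side default. The real json_value helper lives in the server utils and additionally guards against type mismatches; this stripped-down version is only a sketch of the behaviour:

    #include <string>
    #include <nlohmann/json.hpp>

    // illustrative only: return body[key] if present and non-null, otherwise the supplied default
    template <typename T>
    static T json_value_sketch(const nlohmann::json & body, const std::string & key, const T & def) {
        const auto it = body.find(key);
        if (it == body.end() || it->is_null()) {
            return def;
        }
        return it->get<T>();
    }

    // usage, mirroring the assignments above:
    //   params.sampling.top_k = json_value_sketch(data, "top_k", defaults.sampling.top_k);
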
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); @@ -690,7 +692,7 @@ struct server_task_result { // using shared_ptr for polymorphism of server_task_result using server_task_result_ptr = std::unique_ptr; -inline std::string stop_type_to_str(stop_type type) { +static inline std::string stop_type_to_str(stop_type type) { switch (type) { case STOP_TYPE_EOS: return "eos"; case STOP_TYPE_WORD: return "word"; @@ -764,13 +766,6 @@ struct completion_token_output { } }; -struct ctx_checkpoint { - llama_pos pos_min; - llama_pos pos_max; - - std::vector data; -}; - struct server_task_result_cmpl_final : server_task_result { int index = 0; @@ -797,11 +792,12 @@ struct server_task_result_cmpl_final : server_task_result { slot_params generation_params; // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_msg oaicompat_msg; + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_msg oaicompat_msg; + std::vector oaicompat_msg_diffs; virtual int get_index() override { @@ -1373,17 +1369,17 @@ struct server_task_result_slot_save_load : server_task_result { { "save_ms", t_ms } }}, }; - } else { - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_restored", n_tokens }, - { "n_read", n_bytes }, - { "timings", { - { "restore_ms", t_ms } - }}, - }; } + + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; } }; @@ -1404,15 +1400,218 @@ struct server_task_result_apply_lora : server_task_result { } }; +struct server_prompt_checkpoint { + llama_pos pos_min; + llama_pos pos_max; + + std::vector data; + + size_t size() const { + return data.size(); + } +}; + +struct server_prompt { + server_tokens tokens; + + std::vector data; + + std::list checkpoints; + + size_t size() const { + size_t res = data.size(); + + for (const auto & checkpoint : checkpoints) { + res += checkpoint.size(); + } + + return res; + } + + int n_tokens() const { + return tokens.size(); + } +}; + +struct server_prompt_cache { + server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { + this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 
0 : limit_size_mib); + this->limit_tokens = limit_tokens; + } + + std::list states; + + // in bytes, 0 = no limit + size_t limit_size = 0; + + // in tokens, 0 = no limit + size_t limit_tokens = 0; + + size_t size() const { + size_t res = 0; + + for (const auto & state : states) { + res += state.size(); + } + + return res; + } + + size_t n_tokens() const { + size_t res = 0; + + for (const auto & state : states) { + res += state.n_tokens(); + } + + return res; + } + + server_prompt * alloc(const server_prompt & prompt, size_t state_size) { + // first check if the current state is contained fully in the cache + for (auto it = states.begin(); it != states.end(); ++it) { + const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); + + if (cur_lcp_len == (int) prompt.tokens.size()) { + SRV_WRN("%s", " - prompt is already in the cache, skipping\n"); + return nullptr; + } + } + + // next, remove any cached prompts that are fully contained in the current prompt + for (auto it = states.begin(); it != states.end();) { + const int len = it->tokens.get_common_prefix(prompt.tokens); + + if (len == (int) it->tokens.size()) { + SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); + + it = states.erase(it); + } else { + ++it; + } + } + + std::vector state_data; + + // check if we can allocate enough memory for the new state + try { + state_data.resize(state_size); + } catch (const std::bad_alloc & e) { + SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); + + limit_size = std::max(1, 0.4*size()); + + SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); + + update(); + + return nullptr; + } + + // TODO: for some reason we can't copy server_tokens, so we have to do this workaround + auto & cur = states.emplace_back(); + cur = { + /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), + /*.data =*/ std::move(state_data), + /*.checkpoints =*/ prompt.checkpoints, + }; + + return &cur; + } + + bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { + const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); + + float f_keep_best = float(lcp_best) / prompt.tokens.size(); + float sim_best = float(lcp_best) / tokens_new.size(); + + SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + auto it_best = states.end(); + + // find the most similar cached prompt, that would also preserve the most context + for (auto it = states.begin(); it != states.end(); ++it) { + const int lcp_cur = it->tokens.get_common_prefix(tokens_new); + + const float f_keep_cur = float(lcp_cur) / it->tokens.size(); + const float sim_cur = float(lcp_cur) / tokens_new.size(); + + // don't trash large prompts + if (f_keep_cur < 0.25f) { + continue; + } + + if (f_keep_best < f_keep_cur && sim_best < sim_cur) { + f_keep_best = f_keep_cur; + sim_best = sim_cur; + + it_best = it; + } + } + + if (it_best != states.end()) { + SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + const size_t size = it_best->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); + if (n != size) { + SRV_WRN("failed to restore state with size %zu\n", size); + + return false; + } + + it_best->data.clear(); + it_best->data.shrink_to_fit(); + + prompt = std::move(*it_best); + + states.erase(it_best); + } + + return true; + } + + void update() { + if (limit_size > 0) { + // always 
keep at least one state, regardless of the limits + while (states.size() > 1 && size() > limit_size) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + if (limit_tokens > 0) { + while (states.size() > 1 && n_tokens() > limit_tokens) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens)\n", + states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens); + + for (const auto & state : states) { + SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + } + } +}; + struct server_slot { int id; - int id_task = -1; - - // only used for completion/embedding/infill/rerank - server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; llama_batch batch_spec = {}; + // TODO: change to unique_ptrs for consistency: llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; @@ -1421,15 +1620,8 @@ struct server_slot { common_speculative * spec = nullptr; - std::vector lora; - int32_t alora_invocation_start = -1; - - // the index relative to completion multi-task request - size_t index = 0; - - struct slot_params params; - - slot_state state = SLOT_STATE_IDLE; + std::unique_ptr task; + std::unique_ptr task_prev; // used for debugging // used to determine the slot that has been used the longest int64_t t_last_used = -1; @@ -1437,38 +1629,66 @@ struct server_slot { // generation props int32_t n_ctx = 0; // context size per slot int32_t n_past = 0; + int32_t n_keep = 0; int32_t n_decoded = 0; int32_t n_remaining = -1; int32_t i_batch = -1; - int32_t n_predict = -1; // TODO: disambiguate from params.n_predict - // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated - int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_cache = 0; int32_t n_prompt_tokens_processed = 0; - // input prompt tokens - server_tokens prompt_tokens; + int32_t n_prompt_tokens() const { + return task->tokens.size(); + } size_t last_nl_pos = 0; std::string generated_text; llama_tokens generated_tokens; - common_chat_msg chat_msg; - server_tokens cache_tokens; + common_chat_msg chat_msg; std::vector generated_token_probs; - std::vector ctx_checkpoints; - bool has_next_token = true; bool has_new_line = false; bool truncated = false; + stop_type stop; std::string stopping_word; + // state + slot_state state = SLOT_STATE_IDLE; + + server_prompt prompt; + + void prompt_save(server_prompt_cache & prompt_cache) const { + assert(prompt.data.size() == 0); + + const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); + + SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", + (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); + + auto * cur = prompt_cache.alloc(prompt, cur_size); + if (cur == nullptr) { + return; + } + + llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); + } + + void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { + bool res = prompt_cache.load(prompt, tokens, ctx, id); + if (!res) { + SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); + } + } + + std::vector lora; + int32_t 
alora_invocation_start = -1; + // sampling json json_schema; @@ -1480,7 +1700,7 @@ struct server_slot { std::vector generated_tool_call_ids; // stats - size_t n_sent_text = 0; // number of sent text character + size_t n_sent_text = 0; // number of sent text character int64_t t_start_process_prompt; int64_t t_start_generation; @@ -1497,19 +1717,17 @@ struct server_slot { void reset() { SLT_DBG(*this, "%s", "\n"); - n_prompt_tokens = 0; n_prompt_tokens_cache = 0; - last_nl_pos = 0; - generated_text = ""; - has_new_line = false; - truncated = false; - stop = STOP_TYPE_NONE; - stopping_word = ""; - n_past = 0; - n_sent_text = 0; - task_type = SERVER_TASK_TYPE_COMPLETION; - chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + last_nl_pos = 0; + generated_text = ""; + has_new_line = false; + truncated = false; + stop = STOP_TYPE_NONE; + stopping_word = ""; + n_past = 0; + n_sent_text = 0; + chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; generated_tokens.clear(); generated_token_probs.clear(); @@ -1521,16 +1739,23 @@ struct server_slot { n_draft_total = 0; n_draft_accepted = 0; + task.reset(); + task_prev.reset(); + // clear alora start alora_invocation_start = -1; } bool need_embd() const { - return server_task_type_need_embd(task_type); + GGML_ASSERT(task); + + return server_task_type_need_embd(task->type); } bool need_logits() const { - return server_task_type_need_logits(task_type); + GGML_ASSERT(task); + + return server_task_type_need_logits(task->type); } // if the context does not have a memory module then all embeddings have to be computed within a single ubatch @@ -1542,18 +1767,22 @@ struct server_slot { } bool can_batch_with(server_slot & other_slot) const { - return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora); + GGML_ASSERT(task); + + return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora); } bool has_budget(const common_params & global_params) { - if (params.n_predict == -1 && global_params.n_predict == -1) { + GGML_ASSERT(task); + + if (task->params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless } n_remaining = -1; - if (params.n_predict != -1) { - n_remaining = params.n_predict - n_decoded; + if (task->params.n_predict != -1) { + n_remaining = task->params.n_predict - n_decoded; } else if (global_params.n_predict != -1) { n_remaining = global_params.n_predict - n_decoded; } @@ -1566,7 +1795,7 @@ struct server_slot { } bool can_speculate() const { - return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; + return ctx_dft; } void add_token(const completion_token_output & token) { @@ -1579,11 +1808,17 @@ struct server_slot { void release() { if (is_processing()) { + GGML_ASSERT(task); + SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); t_last_used = ggml_time_us(); t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; state = SLOT_STATE_IDLE; + + task_prev = std::move(task); + task.reset(); + callback_on_release(id); } } @@ -1592,19 +1827,19 @@ struct server_slot { result_timings timings; timings.cache_n = n_prompt_tokens_cache; - timings.prompt_n = n_prompt_tokens_processed; - timings.prompt_ms = t_prompt_processing; + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; - timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * 
n_prompt_tokens_processed; - timings.predicted_n = n_decoded; - timings.predicted_ms = t_token_generation; + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; timings.predicted_per_token_ms = t_token_generation / n_decoded; - timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; // Add speculative metrics if (n_draft_total > 0) { - timings.draft_n = n_draft_total; + timings.draft_n = n_draft_total; timings.draft_n_accepted = n_draft_accepted; } @@ -1612,14 +1847,16 @@ struct server_slot { } const common_chat_msg & update_chat_msg(std::vector & diffs) { + GGML_ASSERT(task); + auto previous_msg = chat_msg; SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); auto new_msg = common_chat_parse( generated_text, /* is_partial= */ stop != STOP_TYPE_EOS, - params.oaicompat_chat_syntax); + task->params.oaicompat_chat_syntax); if (!new_msg.empty()) { - new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id); + new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); chat_msg = new_msg; diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg); } @@ -1627,9 +1864,11 @@ struct server_slot { } size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { + GGML_ASSERT(task); + size_t stop_pos = std::string::npos; - for (const std::string & word : params.antiprompt) { + for (const std::string & word : task->params.antiprompt) { size_t pos; if (is_full_stop) { @@ -1682,43 +1921,36 @@ struct server_slot { } json to_json(bool only_metrics = false) const { - if (only_metrics) { - return json { - {"id", id}, - {"id_task", id_task}, - {"n_ctx", n_ctx}, - {"speculative", can_speculate()}, - {"is_processing", is_processing()}, - {"params", params.to_json(true)}, - {"next_token", - { - {"has_next_token", has_next_token}, - {"has_new_line", has_new_line}, - {"n_remain", n_remaining}, - {"n_decoded", n_decoded}, - } - }, - }; - } + json res; - return json { + res = { {"id", id}, - {"id_task", id_task}, {"n_ctx", n_ctx}, {"speculative", can_speculate()}, {"is_processing", is_processing()}, - {"params", params.to_json()}, - {"prompt", prompt_tokens.detokenize(ctx, true)}, - {"next_token", + }; + + const auto & ptask = task ? 
task : task_prev; + + if (ptask) { + res["id_task"] = ptask->id; + res["params"] = ptask->params.to_json(only_metrics); + res["next_token"] = { { {"has_next_token", has_next_token}, {"has_new_line", has_new_line}, {"n_remain", n_remaining}, {"n_decoded", n_decoded}, - {"stopping_word", stopping_word}, } - }, - }; + }; + + if (!only_metrics) { + res["prompt"] = ptask->tokens.detokenize(ctx, true); + res["generated"] = generated_text; + } + } + + return res; } }; @@ -2109,11 +2341,14 @@ struct server_context { // slots / clients std::vector slots; - json default_generation_settings_for_props; + + int slots_debug = 0; server_queue queue_tasks; server_response queue_results; + std::unique_ptr prompt_cache; + server_metrics metrics; // Necessary similarity of prompt for slot selection @@ -2268,9 +2503,8 @@ struct server_context { slot.id = i; slot.ctx = ctx; slot.n_ctx = n_ctx_slot; - slot.n_predict = params_base.n_predict; slot.mctx = mctx; - slot.cache_tokens.has_mtmd = mctx != nullptr; + slot.prompt.tokens.has_mtmd = mctx != nullptr; if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); @@ -2286,16 +2520,13 @@ struct server_context { SRV_ERR("%s", "failed to create speculator\n"); return; } - for (auto &pair : params_base.speculative.replacements) { + for (auto & pair : params_base.speculative.replacements) { common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); } } SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); - slot.params.sampling = params_base.sampling; - slot.params.n_keep = params_base.n_keep; - slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); }; @@ -2305,7 +2536,14 @@ struct server_context { slots.push_back(std::move(slot)); } - default_generation_settings_for_props = slots[0].to_json(); + { + const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG"); + slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0; + + if (slots_debug) { + SRV_WRN("slots debug = %d\n", slots_debug); + } + } // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) @@ -2316,11 +2554,25 @@ struct server_context { metrics.init(); + if (params_base.cache_ram_mib != 0) { + if (params_base.cache_ram_mib < 0) { + SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit"); + } else { + SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib); + } + SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n"); + + prompt_cache = std::make_unique(params_base.cache_ram_mib, n_ctx); + } else { + SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); + } + SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); + // thinking is enabled if: // 1. It's not explicitly disabled (reasoning_budget == 0) // 2. The chat template supports it const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); - SRV_INF("Enable thinking? 
%d\n", enable_thinking); + SRV_INF("thinking = %d\n", enable_thinking); oai_parser_opt = { /* use_jinja */ params_base.use_jinja, @@ -2347,10 +2599,11 @@ struct server_context { server_slot * get_available_slot(const server_task & task) { server_slot * ret = nullptr; + bool update_cache = false; + // find the slot that has at least n% prompt similarity if (ret == nullptr && slot_prompt_similarity != 0.0f) { - int lcs_len = 0; - float similarity = 0; + float sim_best = 0; for (server_slot & slot : slots) { // skip the slot if it is not available @@ -2358,27 +2611,34 @@ struct server_context { continue; } + const auto & tokens = slot.prompt.tokens; + // skip the slot if it does not contains cached tokens - if (slot.cache_tokens.empty()) { + if (tokens.empty()) { continue; } - // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); - - // fraction of the common subsequence length compared to the current slot's prompt length - float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); + // fraction of the Longest Common Prefix length with respect to the input prompt length + const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size(); // select the current slot if the criteria match - if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { - lcs_len = cur_lcs_len; - similarity = cur_similarity; + if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { + sim_best = sim_cur; + ret = &slot; } } if (ret != nullptr) { - SLT_INF(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %.3f (> %.3f thold)\n", lcs_len, similarity, slot_prompt_similarity); + const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size(); + + SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", + sim_best, slot_prompt_similarity, f_keep); + + // if we are about to lose a large portion of the existing context - save it in the prompt cache + if (f_keep < 0.5f) { + update_cache = true; + } } } @@ -2401,6 +2661,36 @@ struct server_context { if (ret != nullptr) { SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last); + + update_cache = true; + } + } + + if (ret) { + const auto & tokens = ret->prompt.tokens; + + update_cache = update_cache && prompt_cache; + + // cache prompts only for completion tasks + update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; + + // don't update the cache if the slot's context is empty + update_cache = update_cache && tokens.size() > 0; + + // TODO: mtmd does not support prompt cache + update_cache = update_cache && (ret->mctx == nullptr); + + if (update_cache) { + SRV_WRN("%s", "updating prompt cache\n"); + + const int64_t t_start = ggml_time_us(); + + ret->prompt_save(*prompt_cache); + ret->prompt_load(*prompt_cache, task.tokens); + + prompt_cache->update(); + + SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); } } @@ -2409,27 +2699,21 @@ struct server_context { bool launch_slot_with_task(server_slot & slot, server_task && task) { slot.reset(); - slot.id_task = task.id; - slot.index = task.index; - slot.task_type = task.type; - slot.params = std::move(task.params); - slot.prompt_tokens = std::move(task.prompt_tokens); - if (!are_lora_equal(slot.params.lora, slot.lora)) { + if (!are_lora_equal(task.params.lora, slot.lora)) { // if lora has changed, check to 
see if the cache should be cleared - if (lora_should_clear_cache(slot.lora, slot.params.lora)) { - SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size()); - slot.cache_tokens.clear(); + if (lora_should_clear_cache(slot.lora, task.params.lora)) { + SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); + slot.prompt.tokens.clear(); } else { - SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size()); + SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size()); } - slot.lora = slot.params.lora; + slot.lora = task.params.lora; } // if using alora, make sure it's only a single one requested and active - size_t alora_invocation_start = slot.prompt_tokens.size(); + size_t alora_invocation_start = task.tokens.size(); if (lora_all_alora(slot.lora)) { - const auto & enabled_ids = lora_get_enabled_ids(slot.lora); // TODO: This will error out if a user requests two aloras, but only // provides the activation string for one. We could, instead search @@ -2448,10 +2732,10 @@ struct server_context { // scan backwards through the prompt tokens to find the last // occurrence of the invocation sequence int match_idx = static_cast(n_invocation_tokens) - 1; - for (int i = slot.prompt_tokens.size() - 1; i >= 0; --i) { + for (int i = task.tokens.size() - 1; i >= 0; --i) { // the token in this position matches the next token to find in // the invocation sequence - if (slot.prompt_tokens[i] == invocation_tokens[match_idx]) { + if (task.tokens[i] == invocation_tokens[match_idx]) { // if it's a full match, we've found the start if (match_idx == 0) { alora_invocation_start = i; @@ -2466,7 +2750,7 @@ struct server_context { } // if the activation string is not found, disable the alora - if (alora_invocation_start == slot.prompt_tokens.size()) { + if (alora_invocation_start == task.tokens.size()) { SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]); slot.lora[enabled_ids[0]].scale = 0.0f; } else { @@ -2475,24 +2759,20 @@ struct server_context { } } - if (!slot.prompt_tokens.validate(ctx)) { + if (!task.tokens.validate(ctx)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; } - SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); - if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { - // Might be better to reject the request with a 400 ? 
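
The slot-selection change above switches from an LCS length to a longest-common-prefix fraction: sim compares the shared prefix against the incoming prompt, f_keep measures how much of the slot's cached prompt would survive reuse, and a low f_keep (below 0.5 here) is what triggers saving the old context into the prompt cache before the slot is handed over. A minimal sketch of those two ratios; the free function and names are illustrative:

    #include <cstddef>

    // lcp_len       : longest common prefix between the slot's cached prompt and the new prompt
    // n_prompt_new  : length of the incoming prompt
    // n_prompt_slot : length of the prompt currently cached in the slot
    struct slot_match_sketch {
        float sim;    // fraction of the new prompt already covered by the cache
        float f_keep; // fraction of the cached prompt that would be kept
    };

    static slot_match_sketch score_slot_sketch(size_t lcp_len, size_t n_prompt_new, size_t n_prompt_slot) {
        slot_match_sketch s;
        s.sim    = n_prompt_new  ? float(lcp_len) / n_prompt_new  : 0.0f;
        s.f_keep = n_prompt_slot ? float(lcp_len) / n_prompt_slot : 0.0f;
        return s;
    }

A slot is picked when its sim beats both the best score so far and the slot_prompt_similarity threshold; only completion tasks on slots with non-empty, non-multimodal context then update the prompt cache.
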
- SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict); - slot.params.n_predict = slot.n_predict; - } + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); + // initialize samplers { if (slot.smpl != nullptr) { common_sampler_free(slot.smpl); } - slot.smpl = common_sampler_init(model, slot.params.sampling); + slot.smpl = common_sampler_init(model, task.params.sampling); if (slot.smpl == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); @@ -2500,12 +2780,15 @@ struct server_context { } } + // initialize draft batch if (slot.ctx_dft) { llama_batch_free(slot.batch_spec); - slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1); } + slot.task = std::make_unique(std::move(task)); + slot.state = SLOT_STATE_STARTED; SLT_INF(slot, "%s", "processing task\n"); @@ -2527,7 +2810,7 @@ struct server_context { slot.sampled = result.tok; slot.generated_text += token_str; - if (slot.params.return_tokens) { + if (slot.task->params.return_tokens) { slot.generated_tokens.push_back(result.tok); } slot.has_next_token = true; @@ -2564,7 +2847,7 @@ struct server_context { } slot.add_token(result); - if (slot.params.stream) { + if (slot.task->params.stream) { send_partial_response(slot, result, false); } } @@ -2586,12 +2869,12 @@ struct server_context { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict); } if (slot.has_new_line) { // require that each new line has a whitespace prefix (i.e. 
indentation) of at least slot.params.n_indent - if (slot.params.n_indent > 0) { + if (slot.task->params.n_indent > 0) { // check the current indentation // TODO: improve by not doing it more than once for each new line if (slot.last_nl_pos > 0) { @@ -2603,7 +2886,7 @@ struct server_context { pos++; } - if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { + if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; @@ -2630,11 +2913,11 @@ struct server_context { slot.has_new_line = true; // if we have seen a new line, we stop after a certain time limit, but only upon another new line - if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { + if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms); } } @@ -2645,7 +2928,7 @@ struct server_context { slot.has_next_token = false; SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", - slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); + slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx); } if (llama_vocab_is_eog(vocab, result.tok)) { @@ -2657,7 +2940,7 @@ struct server_context { const auto n_ctx_train = llama_model_n_ctx_train(model); - if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) { slot.truncated = true; slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; // stop prediction @@ -2665,7 +2948,7 @@ struct server_context { SLT_WRN(slot, "n_predict (%d) is set for infinite generation. 
" "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", - slot.params.n_predict, n_ctx_train); + slot.task->params.n_predict, n_ctx_train); } SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); @@ -2674,7 +2957,7 @@ struct server_context { } void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { - size_t n_probs = slot.params.sampling.n_probs; + size_t n_probs = slot.task->params.sampling.n_probs; size_t n_vocab = llama_vocab_n_tokens(vocab); if (post_sampling) { @@ -2728,7 +3011,7 @@ struct server_context { } void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(slot.id_task, error, type, slot.n_prompt_tokens, slot.n_ctx); + send_error(slot.task->id, error, type, slot.n_prompt_tokens(), slot.n_ctx); } void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { @@ -2749,7 +3032,7 @@ struct server_context { } // if multimodal is enabled, send an error and return false - bool ensure_no_mtmd(const int id_task) { + bool check_no_mtmd(const int id_task) { if (mctx) { send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); return false; @@ -2760,14 +3043,14 @@ struct server_context { void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; + res->id = slot.task->id; + res->index = slot.task->index; if (is_progress) { res->is_progress = true; - res->progress.total = slot.n_prompt_tokens; + res->progress.total = slot.n_prompt_tokens(); res->progress.cache = slot.n_prompt_tokens_cache; - res->progress.processed = slot.cache_tokens.size(); + res->progress.processed = slot.prompt.tokens.size(); res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt / 1000); } else { res->content = tkn.text_to_send; @@ -2777,21 +3060,21 @@ struct server_context { } res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; - res->post_sampling_probs = slot.params.post_sampling_probs; + res->n_prompt_tokens = slot.n_prompt_tokens(); + res->post_sampling_probs = slot.task->params.post_sampling_probs; - res->verbose = slot.params.verbose; - res->oaicompat = slot.params.oaicompat; - res->oaicompat_model = slot.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->verbose = slot.task->params.verbose; + res->oaicompat = slot.task->params.oaicompat; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; // populate res.probs_output - if (slot.params.sampling.n_probs > 0) { + if (slot.task->params.sampling.n_probs > 0) { res->prob_output = tkn; // copy the token probs } // populate timings if this is final response or timings_per_token is enabled - if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) { + if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) { res->timings = slot.get_timings(); } @@ -2800,36 +3083,37 @@ struct server_context { void send_final_response(server_slot & slot) { auto res = std::make_unique(); - res->id = slot.id_task; - res->id_slot = slot.id; - res->index = slot.index; + 
res->id = slot.task->id; + res->id_slot = slot.id; + + res->index = slot.task->index; res->content = slot.generated_text; res->tokens = std::move(slot.generated_tokens); res->timings = slot.get_timings(); - res->prompt = slot.prompt_tokens.detokenize(ctx, true); - res->response_fields = std::move(slot.params.response_fields); + res->prompt = slot.task->tokens.detokenize(ctx, true); + res->response_fields = std::move(slot.task->params.response_fields); res->truncated = slot.truncated; res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_prompt_tokens = slot.n_prompt_tokens(); res->n_tokens_cached = slot.n_past; res->has_new_line = slot.has_new_line; res->stopping_word = slot.stopping_word; res->stop = slot.stop; - res->post_sampling_probs = slot.params.post_sampling_probs; + res->post_sampling_probs = slot.task->params.post_sampling_probs; - res->verbose = slot.params.verbose; - res->stream = slot.params.stream; - res->include_usage = slot.params.include_usage; - res->oaicompat = slot.params.oaicompat; - res->oaicompat_model = slot.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); + res->verbose = slot.task->params.verbose; + res->stream = slot.task->params.stream; + res->include_usage = slot.task->params.include_usage; + res->oaicompat = slot.task->params.oaicompat; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; + res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); // populate res.probs_output - if (slot.params.sampling.n_probs > 0) { - if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { + if (slot.task->params.sampling.n_probs > 0) { + if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); @@ -2843,17 +3127,17 @@ struct server_context { } } - res->generation_params = slot.params; // copy the parameters + res->generation_params = slot.task->params; // copy the parameters queue_results.send(std::move(res)); } void send_embedding(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; - res->n_tokens = slot.n_prompt_tokens; - res->oaicompat = slot.params.oaicompat; + res->id = slot.task->id; + res->index = slot.task->index; + res->n_tokens = slot.n_prompt_tokens(); + res->oaicompat = slot.task->params.oaicompat; const int n_embd = llama_model_n_embd(model); @@ -2880,12 +3164,12 @@ struct server_context { // normalize only when there is pooling if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, slot.params.embd_normalize); + common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize); res->embedding.push_back(embd_res); break; - } else { - res->embedding.emplace_back(embd, embd + n_embd); } + + res->embedding.emplace_back(embd, embd + n_embd); } SLT_DBG(slot, "%s", "sending embeddings\n"); @@ -2895,9 +3179,9 @@ struct server_context { void send_rerank(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; - res->n_tokens = slot.n_prompt_tokens; + res->id = slot.task->id; + res->index = slot.task->index; + res->n_tokens = 
slot.n_prompt_tokens(); for (int i = 0; i < batch.n_tokens; ++i) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { @@ -3034,7 +3318,7 @@ struct server_context { case SERVER_TASK_TYPE_EMBEDDING: case SERVER_TASK_TYPE_RERANK: { - const int id_slot = task.id_selected_slot; + const int id_slot = task.id_slot; server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); @@ -3061,7 +3345,7 @@ struct server_context { { // release slot linked with the task id for (auto & slot : slots) { - if (slot.id_task == task.id_target) { + if (slot.task && slot.task->id == task.id_target) { slot.release(); break; } @@ -3079,7 +3363,7 @@ struct server_context { int n_processing_slots = 0; for (server_slot & slot : slots) { - json slot_data = slot.to_json(true); + json slot_data = slot.to_json(slots_debug == 0); if (slot.is_processing()) { n_processing_slots++; @@ -3121,7 +3405,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_SAVE: { - if (!ensure_no_mtmd(task.id)) { + if (!check_no_mtmd(task.id)) { break; } @@ -3138,13 +3422,13 @@ struct server_context { break; } - const size_t token_count = slot->cache_tokens.size(); + const size_t token_count = slot->prompt.tokens.size(); const int64_t t_start = ggml_time_us(); std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - const llama_tokens & tokens = slot->cache_tokens.get_text_tokens(); + const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens(); const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); @@ -3162,7 +3446,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { - if (!ensure_no_mtmd(task.id)) break; + if (!check_no_mtmd(task.id)) break; int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -3186,13 +3470,13 @@ struct server_context { size_t token_count = 0; size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { - slot->cache_tokens.clear(); // KV may already been invalidated? + slot->prompt.tokens.clear(); // KV may already been invalidated? 
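
The SLOT_SAVE and SLOT_RESTORE handlers above persist a slot's sequence state with llama_state_seq_save_file and read it back with llama_state_seq_load_file, alongside the slot's text tokens. A compact sketch of that round trip using only the calls visible in this hunk; error handling and the surrounding task plumbing are omitted, and the variable names are illustrative:

    #include <string>
    #include <vector>
    #include "llama.h"

    // minimal sketch of the save/restore round trip used by the slot handlers above
    // (the real code also repopulates slot->prompt.tokens from the restored token list)
    static void slot_state_roundtrip_sketch(llama_context * ctx, llama_seq_id slot_id,
                                            const std::vector<llama_token> & tokens,
                                            const std::string & filepath, size_t n_ctx) {
        // persist the KV/sequence state of seq_id `slot_id` together with its prompt tokens
        llama_state_seq_save_file(ctx, filepath.c_str(), slot_id, tokens.data(), tokens.size());

        // later: load the file back into the same seq_id; token_count reports how many tokens were stored
        std::vector<llama_token> restored(n_ctx);
        size_t token_count = 0;
        const size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot_id,
                                                       restored.data(), restored.size(), &token_count);
        if (nread == 0) {
            return; // invalid save file or no space in the KV cache, matching the check above
        }
        restored.resize(token_count);
    }
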
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); break; } tokens.resize(token_count); - slot->cache_tokens.clear(); - slot->cache_tokens.insert(tokens); + slot->prompt.tokens.clear(); + slot->prompt.tokens.insert(tokens); const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; @@ -3209,7 +3493,9 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_ERASE: { - if (!ensure_no_mtmd(task.id)) break; + if (!check_no_mtmd(task.id)) { + break; + } int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -3224,9 +3510,9 @@ struct server_context { } // Erase token cache - const size_t n_erased = slot->cache_tokens.size(); + const size_t n_erased = slot->prompt.tokens.size(); llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1); - slot->cache_tokens.clear(); + slot->prompt.tokens.clear(); auto res = std::make_unique(); res->id = task.id; @@ -3282,8 +3568,8 @@ struct server_context { if (!params_base.ctx_shift) { // this check is redundant (for good) // we should never get here, because generation should already stopped in process_token() - slot.release(); send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + slot.release(); continue; } @@ -3294,9 +3580,16 @@ struct server_context { } // Shift context - const int n_keep = slot.params.n_keep + add_bos_token; + int n_keep = slot.task->params.n_keep < 0 ? slot.n_prompt_tokens() : slot.task->params.n_keep; + + if (add_bos_token) { + n_keep += 1; + } + + n_keep = std::min(slot.n_ctx - 4, n_keep); + const int n_left = slot.n_past - n_keep; - const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); + const int n_discard = slot.task->params.n_discard ? 
slot.task->params.n_discard : (n_left / 2); SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); @@ -3305,14 +3598,14 @@ struct server_context { // add generated tokens to cache { - llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy + llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { new_tokens[i - n_discard] = new_tokens[i]; } - new_tokens.resize(slot.cache_tokens.size() - n_discard); - slot.cache_tokens.clear(); - slot.cache_tokens.insert(new_tokens); + new_tokens.resize(slot.prompt.tokens.size() - n_discard); + slot.prompt.tokens.clear(); + slot.prompt.tokens.insert(new_tokens); } slot.n_past -= n_discard; @@ -3328,7 +3621,8 @@ struct server_context { server_slot * slot_batched = nullptr; auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end(); + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); }; // frist, add sampled tokens from any ongoing sequences @@ -3349,10 +3643,10 @@ struct server_context { common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); slot.n_past += 1; - slot.cache_tokens.push_back(slot.sampled); + slot.prompt.tokens.push_back(slot.sampled); SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); + slot.n_ctx, slot.n_past, (int) slot.prompt.tokens.size(), slot.truncated); } // process in chunks of params.n_batch @@ -3375,7 +3669,7 @@ struct server_context { // this slot still has a prompt to be processed if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { - auto & prompt_tokens = slot.prompt_tokens; + const auto & input_tokens = slot.task->tokens; // TODO: maybe move branch to outside of this loop in the future if (slot.state == SLOT_STATE_STARTED) { @@ -3383,104 +3677,64 @@ struct server_context { slot.t_start_generation = 0; slot.n_past = 0; - slot.n_prompt_tokens = prompt_tokens.size(); slot.state = SLOT_STATE_PROCESSING_PROMPT; - SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", + slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens()); // print prompt tokens (for debugging) /*if (1) { // first 16 tokens (avoid flooding logs) - for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + for (int i = 0; i < std::min(16, input_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); } } else { // all - for (int i = 0; i < (int) prompt_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + for (int i = 0; i < (int) input_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); } }*/ // empty prompt passed -> release the slot and send empty 
response - if (prompt_tokens.empty()) { + if (input_tokens.empty()) { SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); - slot.release(); slot.print_timings(); send_final_response(slot); + slot.release(); + continue; } // TODO: support memory-less logits computation if (slot.need_logits() && !llama_get_memory(ctx)) { - slot.release(); send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); + slot.release(); continue; } if (!slot.can_split()) { - if (slot.n_prompt_tokens > n_ubatch) { - slot.release(); + if (slot.n_prompt_tokens() > n_ubatch) { send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); + slot.release(); continue; } - if (slot.n_prompt_tokens > slot.n_ctx) { - slot.release(); + if (slot.n_prompt_tokens() > slot.n_ctx) { send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); continue; } } else { - if (!params_base.ctx_shift) { - // if context shift is disabled, we make sure prompt size is smaller than KV size - // TODO: there should be a separate parameter that control prompt truncation - // context shift should be applied only during the generation phase - if (slot.n_prompt_tokens >= slot.n_ctx) { - slot.release(); - send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE); - continue; - } - } - if (slot.params.n_keep < 0) { - slot.params.n_keep = slot.n_prompt_tokens; - } - slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - - // if input prompt is too big, truncate it - if (slot.n_prompt_tokens >= slot.n_ctx) { - if (mctx) { - // we should never reach this - GGML_ABORT("not supported by multimodal"); - } - const int n_left = slot.n_ctx - slot.params.n_keep; - - const int n_block_size = n_left / 2; - const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - - const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens(); - llama_tokens new_tokens( - curr_tokens.begin(), - curr_tokens.begin() + slot.params.n_keep); - - new_tokens.insert( - new_tokens.end(), - curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, - curr_tokens.end()); - - prompt_tokens.clear(); - prompt_tokens.insert(new_tokens); - - slot.truncated = true; - slot.n_prompt_tokens = prompt_tokens.size(); - - SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); - - GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); + if (slot.n_prompt_tokens() >= slot.n_ctx) { + send_error(slot, "the request exceeds the available context size. 
try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); + continue; } - if (slot.params.cache_prompt) { + if (slot.task->params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); + slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens); // if there is an alora invoked, don't cache after the invocation start if (slot.alora_invocation_start >= 0) { @@ -3500,13 +3754,13 @@ struct server_context { SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); - while (head_c < slot.cache_tokens.size() && - head_p < prompt_tokens.size()) { + while (head_c < slot.prompt.tokens.size() && + head_p < input_tokens.size()) { size_t n_match = 0; - while (head_c + n_match < slot.cache_tokens.size() && - head_p + n_match < prompt_tokens.size() && - slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + while (head_c + n_match < slot.prompt.tokens.size() && + head_p + n_match < input_tokens.size() && + slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { n_match++; } @@ -3523,7 +3777,7 @@ struct server_context { llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { - slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]); + slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); slot.n_past++; } @@ -3547,41 +3801,83 @@ struct server_context { // the largest pos_min required for a checkpoint to be useful const auto pos_min_thold = std::max(0, slot.n_past - n_swa); - if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { + if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); if (pos_min == -1) { - SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); + SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); } + // when the prompt prefix does not match, print the tokens around the mismatch + // this is useful for debugging prompt caching + { + const int np0 = std::max(slot.n_past - 4, 0); + const int np1 = std::min(slot.n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); + + std::stringstream ss0; + std::stringstream ss1; + + std::stringstream st0; + std::stringstream st1; + + ss0 << "old: ... "; + ss1 << "new: ... 
"; + + for (int i = np0; i < np1; i++) { + if (i == slot.n_past) { + ss0 << " | "; + ss1 << " | "; + } + + { + const auto token = slot.prompt.tokens[i]; + const auto piece = common_token_to_piece(ctx, token); + ss0 << piece; + st0 << std::setw(8) << token; + } + + { + const auto token = slot.task->tokens[i]; + const auto piece = common_token_to_piece(ctx, token); + ss1 << piece; + st1 << std::setw(8) << token; + } + } + + SLT_WRN(slot, "%s\n", ss0.str().c_str()); + SLT_WRN(slot, "%s\n", ss1.str().c_str()); + + SLT_WRN(slot, "%s\n", st0.str().c_str()); + SLT_WRN(slot, "%s\n", st1.str().c_str()); + } + if (pos_min > pos_min_thold) { - SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); + SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); // search for a context checkpoint const auto it = std::find_if( - slot.ctx_checkpoints.rbegin(), - slot.ctx_checkpoints.rend(), + slot.prompt.checkpoints.rbegin(), + slot.prompt.checkpoints.rend(), [&](const auto & cur) { // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] return cur.pos_min < pos_min_thold; } ); - bool do_reset = it == slot.ctx_checkpoints.rend(); - //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false"); + bool do_reset = it == slot.prompt.checkpoints.rend(); if (!do_reset) { // restore the context checkpoint - const size_t ctx_checkpoint_size = it->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + const size_t checkpoint_size = it->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - if (n != ctx_checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024); + if (n != checkpoint_size) { + SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); do_reset = true; //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); } else { slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max)); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); } } @@ -3595,19 +3891,21 @@ struct server_context { { // erase any checkpoints with pos_min > pos_min_thold - for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) { - const auto & cur = slot.ctx_checkpoints[i]; + for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { + const auto & cur = *it; if (cur.pos_min > pos_min_thold) { SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); - slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i); + it = slot.prompt.checkpoints.erase(it); + } 
else { + ++it; } } } } // [TAG_PROMPT_LOGITS] - if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { - SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens); + if (slot.n_past == slot.n_prompt_tokens() && slot.n_past > 0) { + SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens()); slot.n_past--; SLT_WRN(slot, "n_past was set to %d\n", slot.n_past); } @@ -3618,7 +3916,7 @@ struct server_context { if (!slot.can_split()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) { continue; } } @@ -3636,28 +3934,28 @@ struct server_context { SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past); // remove the non-common part from the cache - slot.cache_tokens.keep_first(slot.n_past); + slot.prompt.tokens.keep_first(slot.n_past); // check if we should process the image - if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { + if (slot.n_past < slot.n_prompt_tokens() && input_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { // process the image int32_t new_n_past; - int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); - int32_t n_pos = new_n_past - slot.n_past; - + int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); - slot.release(); send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + slot.release(); continue; } // add the image chunk to cache { - const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past); - slot.cache_tokens.push_back(chunk.get()); // copy + const auto & chunk = input_tokens.find_chunk(slot.n_past); + slot.prompt.tokens.push_back(chunk.get()); // copy } + const int32_t n_pos = new_n_past - slot.n_past; + slot.n_past += n_pos; slot.n_prompt_tokens_processed += n_pos; } @@ -3678,6 +3976,9 @@ struct server_context { bool do_checkpoint = params_base.n_ctx_checkpoints > 0; + // make checkpoints only for completion tasks + do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; + // make a checkpoint of the parts of the memory that cannot be rolled back. 
// checkpoints are created only if: // - the model uses SWA and we are not using `swa_full` @@ -3691,9 +3992,9 @@ struct server_context { ); // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + while (slot.n_past < slot.n_prompt_tokens() && batch.n_tokens < n_batch) { // get next token to process - llama_token cur_tok = slot.prompt_tokens[slot.n_past]; + llama_token cur_tok = input_tokens[slot.n_past]; if (cur_tok == LLAMA_TOKEN_NULL) { break; // end of text chunk } @@ -3707,36 +4008,33 @@ struct server_context { } // embedding requires all tokens in the batch to be output - const bool need_embd = server_task_type_need_embd(slot.task_type); - - common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); - slot.cache_tokens.push_back(cur_tok); + common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, slot.need_embd()); + slot.prompt.tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; slot.n_past++; // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. - if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) { + if (do_checkpoint && slot.n_prompt_tokens() - slot.n_past == 64) { break; } } // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); - SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens()); // entire prompt has been processed - if (slot.n_past == slot.n_prompt_tokens) { + if (slot.n_past == slot.n_prompt_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; GGML_ASSERT(batch.n_tokens > 0); - GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size()); common_sampler_reset(slot.smpl); // Process all prompt tokens through sampler system - for (int i = 0; i < slot.n_prompt_tokens; ++i) { - llama_token id = slot.prompt_tokens[i]; + for (int i = 0; i < slot.n_prompt_tokens(); ++i) { + llama_token id = input_tokens[i]; if (id != LLAMA_TOKEN_NULL) { common_sampler_accept(slot.smpl, id, false); } @@ -3757,21 +4055,22 @@ struct server_context { do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64); + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); if (do_checkpoint) { - while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { // make room for the new checkpoint, if needed - const auto & cur = slot.ctx_checkpoints.front(); + const auto & cur = slot.prompt.checkpoints.front(); + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin()); + slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); } const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{ + auto & 
cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ /*.pos_min = */ pos_min, /*.pos_max = */ pos_max, /*.data = */ std::vector(checkpoint_size), @@ -3779,8 +4078,8 @@ struct server_context { llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); } } } @@ -3854,8 +4153,8 @@ struct server_context { if (!err.empty()) { SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); for (auto & slot : slots) { - slot.release(); send_error(slot, err); + slot.release(); } break; } @@ -3878,7 +4177,7 @@ struct server_context { for (auto & slot : slots) { // optionally send prompt processing progress if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.params.stream && slot.params.return_progress) { + if (slot.task->params.stream && slot.task->params.return_progress) { send_partial_response(slot, {}, true); } } @@ -3888,7 +4187,7 @@ struct server_context { } if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { + if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { // prompt evaluated for embedding send_embedding(slot, batch_view); slot.release(); @@ -3896,7 +4195,7 @@ struct server_context { continue; // continue loop of slots } - if (slot.task_type == SERVER_TASK_TYPE_RERANK) { + if (slot.task->type == SERVER_TASK_TYPE_RERANK) { send_rerank(slot, batch_view); slot.release(); slot.i_batch = -1; @@ -3934,16 +4233,17 @@ struct server_context { result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - if (slot.params.sampling.n_probs > 0) { - populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx); + if (slot.task->params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); } if (!process_token(result, slot)) { // release slot because of stop condition - slot.release(); slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); + slot.release(); + continue; } } @@ -3964,7 +4264,7 @@ struct server_context { } // determine the max draft that fits the current slot state - int n_draft_max = slot.params.speculative.n_max; + int n_draft_max = slot.task->params.speculative.n_max; // note: n_past is not yet increased for the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts @@ -3976,8 +4276,8 @@ struct server_context { SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); - if (n_draft_max < slot.params.speculative.n_min) { - SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min); + if (n_draft_max < slot.task->params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, 
slot.task->params.speculative.n_min); continue; } @@ -3985,16 +4285,16 @@ struct server_context { llama_token id = slot.sampled; struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; - params_spec.p_min = slot.params.speculative.p_min; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; + params_spec.p_min = slot.task->params.speculative.p_min; - const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); + const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); // ignore small drafts - if (slot.params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); + if (slot.task->params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); continue; } @@ -4023,8 +4323,8 @@ struct server_context { // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; - slot.cache_tokens.push_back(id); - slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); + slot.prompt.tokens.push_back(id); + slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1); @@ -4038,11 +4338,11 @@ struct server_context { // TODO: set result.probs if (!process_token(result, slot)) { - // release slot because of stop condition - slot.release(); slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); + slot.release(); + break; } } @@ -4310,18 +4610,18 @@ int main(int argc, char ** argv) { } // TODO: get rid of this dynamic_cast - auto res_metrics = dynamic_cast(result.get()); - GGML_ASSERT(res_metrics != nullptr); + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); // optionally return "fail_on_no_slot" error if (req.has_param("fail_on_no_slot")) { - if (res_metrics->n_idle_slots == 0) { + if (res_task->n_idle_slots == 0) { res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); return; } } - res_ok(res, res_metrics->slots_data); + res_ok(res, res_task->slots_data); }; const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { @@ -4349,56 +4649,56 @@ int main(int argc, char ** argv) { } // TODO: get rid of this dynamic_cast - auto res_metrics = dynamic_cast(result.get()); - GGML_ASSERT(res_metrics != nullptr); + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names json all_metrics_def = json { {"counter", {{ {"name", "prompt_tokens_total"}, {"help", "Number of prompt tokens processed."}, - {"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total} + {"value", (uint64_t) res_task->n_prompt_tokens_processed_total} }, { {"name", "prompt_seconds_total"}, {"help", "Prompt process time"}, - {"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3} + {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3} }, { {"name", "tokens_predicted_total"}, {"help", "Number of generation tokens processed."}, - {"value", (uint64_t) res_metrics->n_tokens_predicted_total} + {"value", (uint64_t) 
res_task->n_tokens_predicted_total} }, { {"name", "tokens_predicted_seconds_total"}, {"help", "Predict process time"}, - {"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3} + {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3} }, { {"name", "n_decode_total"}, {"help", "Total number of llama_decode() calls"}, - {"value", res_metrics->n_decode_total} + {"value", res_task->n_decode_total} }, { {"name", "n_past_max"}, {"help", "Largest observed n_past."}, - {"value", res_metrics->n_past_max} + {"value", res_task->n_past_max} }, { {"name", "n_busy_slots_per_decode"}, {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) res_metrics->n_busy_slots_total / std::max((float) res_metrics->n_decode_total, 1.f)} + {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} }}}, {"gauge", {{ {"name", "prompt_tokens_seconds"}, {"help", "Average prompt throughput in tokens/s."}, - {"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.} + {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.} },{ {"name", "predicted_tokens_seconds"}, {"help", "Average generation throughput in tokens/s."}, - {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.} + {"value", res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.} },{ {"name", "requests_processing"}, {"help", "Number of requests processing."}, - {"value", (uint64_t) res_metrics->n_processing_slots} + {"value", (uint64_t) res_task->n_processing_slots} },{ {"name", "requests_deferred"}, {"help", "Number of requests deferred."}, - {"value", (uint64_t) res_metrics->n_tasks_deferred} + {"value", (uint64_t) res_task->n_tasks_deferred} }}} }; @@ -4419,7 +4719,7 @@ int main(int argc, char ** argv) { } } - res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start)); + res.set_header("Process-Start-Time-Unix", std::to_string(res_task->t_start)); res.set_content(prometheus.str(), "text/plain; version=0.0.4"); res.status = 200; // HTTP OK @@ -4543,9 +4843,22 @@ int main(int argc, char ** argv) { }; const auto handle_props = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { + json default_generation_settings_for_props; + + { + slot_params params; + + params.sampling = ctx_server.params_base.sampling; + + default_generation_settings_for_props = json { + {"params", params.to_json(true)}, + {"n_ctx", ctx_server.slots[0].n_ctx}, + }; + } + // this endpoint is publicly available, please only return what is safe to be exposed json data = { - { "default_generation_settings", ctx_server.default_generation_settings_for_props }, + { "default_generation_settings", default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model.path }, { "modalities", json { @@ -4650,12 +4963,12 @@ int main(int argc, char ** argv) { task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(inputs[i]); - task.params = server_task::params_from_json_cmpl( + task.tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( ctx_server.ctx, ctx_server.params_base, data); - task.id_selected_slot = json_value(data, "id_slot", -1); + task.id_slot = 
json_value(data, "id_slot", -1); // OAI-compat task.params.oaicompat = oaicompat; @@ -5024,9 +5337,9 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < tokenized_prompts.size(); i++) { server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tokenized_prompts[i]); // OAI-compat task.params.oaicompat = oaicompat; @@ -5122,10 +5435,10 @@ int main(int argc, char ** argv) { tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); - server_task task = server_task(SERVER_TASK_TYPE_RERANK); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.prompt_tokens = std::move(tmp); + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tmp); tasks.push_back(std::move(task)); } @@ -5383,7 +5696,7 @@ int main(int argc, char ** argv) { #endif LOG_INF("%s: server is listening on %s - starting the main loop\n", __func__, - is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() : + is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() : string_format("http://%s:%d", params.hostname.c_str(), params.port).c_str()); // this call blocks the main thread until queue_tasks.terminate() is called diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index 829af2ebe7bfb..720b136b05175 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -66,8 +66,7 @@ def test_server_slots(): assert len(res.body) == server.n_slots assert server.n_ctx is not None and server.n_slots is not None assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots - assert "params" in res.body[0] - assert res.body[0]["params"]["seed"] == server.seed + assert "params" not in res.body[0] def test_load_split_model(): diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 2979ed4bb7b12..6e5a3488e789b 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -19,8 +19,8 @@ def create_server(): (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None), (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'), (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None), (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, 
"length", False, None), (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None), ] @@ -54,7 +54,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", [ ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"), - ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"), + ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"), ] ) def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py index 11483e679a505..00ba78cf67c09 100644 --- a/tools/server/tests/unit/test_completion.py +++ b/tools/server/tests/unit/test_completion.py @@ -16,7 +16,7 @@ def create_server(): @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [ ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False), - ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True), + ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True), ]) def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool): global server @@ -41,7 +41,7 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [ ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False), - ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), + ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), ]) def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool): global server diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py index 92e49f2bb05a4..4adbbde64f594 100644 --- a/tools/server/tests/unit/test_ctx_shift.py +++ b/tools/server/tests/unit/test_ctx_shift.py @@ -4,6 +4,12 @@ server = ServerPreset.tinyllama2() +SHORT_TEXT = """ +Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. +Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. +Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. +""".strip() + LONG_TEXT = """ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. 
@@ -21,19 +27,18 @@ def create_server(): def test_ctx_shift_enabled(): - # the prompt is 301 tokens + # the prompt is 226 tokens # the slot context is 512/2 = 256 tokens - # the prompt is truncated to keep the last (301 - 256/2) = 173 tokens # 96 tokens are generated thanks to shifting the context when it gets full global server server.enable_ctx_shift = True server.start() res = server.make_request("POST", "/completion", data={ "n_predict": 96, - "prompt": LONG_TEXT, + "prompt": SHORT_TEXT, }) assert res.status_code == 200 - assert res.body["timings"]["prompt_n"] == 173 + assert res.body["timings"]["prompt_n"] == 226 assert res.body["timings"]["predicted_n"] == 96 assert res.body["truncated"] is True diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 4ca1423aaf2d4..f175115f4fd6a 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -31,10 +31,10 @@ using json = nlohmann::ordered_json; -#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define SRV_WRN(fmt, ...) 
LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) @@ -1102,6 +1102,7 @@ struct server_tokens { ~server_tokens() = default; // Prevent copying + // TODO: server_tokens should be copyable - remove this: server_tokens(const server_tokens&) = delete; server_tokens& operator=(const server_tokens&) = delete; @@ -1119,7 +1120,7 @@ struct server_tokens { } } - server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} + server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} // for debugging std::string str() const { @@ -1144,9 +1145,8 @@ struct server_tokens { auto it = map_pos_to_media.find(pos); if (it != map_pos_to_media.end()) { return it->second; - } else { - throw std::runtime_error("Chunk not found"); } + throw std::runtime_error("Chunk not found"); } void push_back(llama_token tok) { @@ -1170,7 +1170,7 @@ struct server_tokens { map_pos_to_media[start_pos] = std::move(new_chunk); } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { size_t n_tokens; - auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); for (size_t i = 0; i < n_tokens; ++i) { push_back(text_tokens[i]); } @@ -1190,7 +1190,7 @@ struct server_tokens { // We could also just check, but this will prevent silently dropping MTMD data. GGML_ASSERT(has_mtmd); for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) { - auto chunk = tokens.map_pos_to_media[it->first].get(); + auto * chunk = tokens.map_pos_to_media[it->first].get(); mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); map_pos_to_media[start_pos+it->first] = std::move(new_chunk); } @@ -1271,33 +1271,52 @@ struct server_tokens { } size_t get_common_prefix(const server_tokens & b) const { - size_t max_idx = std::min(tokens.size(), b.tokens.size()); + const size_t max_idx = std::min(tokens.size(), b.tokens.size()); + + if (!has_mtmd) { + for (size_t i = 0; i < max_idx; ++i) { + if (tokens[i] == b.tokens[i]) { + continue; + } + + return i; + } + + return max_idx; + } + for (size_t i = 0; i < max_idx; ++i) { - auto & ai = tokens[i]; - auto & bi = b.tokens[i]; + const llama_token ai = tokens[i]; + const llama_token bi = b.tokens[i]; if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { - GGML_ASSERT(has_mtmd); const auto & a_chunk = find_chunk(i); const auto & b_chunk = b.find_chunk(i); + GGML_ASSERT(a_chunk && b_chunk); - std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get()); - std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get()); - size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get()); - size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get()); - if (ai_id == bi_id && a_pos == b_pos) { - GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen - i += a_pos - 1; // will be +1 by the for loop + + const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); + const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); + + const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get()); + const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get()); + + if (id_ai == id_bi && pos_a == pos_b) { + GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen + i += pos_a - 1; // will be +1 by the for loop continue; - } else { - return i; } - } else if (ai == bi) { - continue; - } else { + return i; } + + if (ai == bi) { + continue; + } + + return i; } + return max_idx; // all tokens are equal } @@ -1308,7 +1327,7 @@ 
struct server_tokens { const int32_t n_vocab = llama_vocab_n_tokens(vocab); for (size_t i = 0; i < tokens.size(); ++i) { - auto & t = tokens[i]; + const auto & t = tokens[i]; if (t == LLAMA_TOKEN_NULL) { try { const auto & chunk = find_chunk(i); @@ -1330,8 +1349,8 @@ struct server_tokens { mtmd_context * mctx, llama_pos n_past, int32_t seq_id, - llama_pos & n_pos_out) { - auto & chunk = find_chunk(n_past); + llama_pos & n_pos_out) const { + const auto & chunk = find_chunk(n_past); const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; SRV_INF("processing %s...\n", name); diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageThinkingBlock.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageThinkingBlock.svelte index 76861a66c6f23..9245ad515333e 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageThinkingBlock.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageThinkingBlock.svelte @@ -4,7 +4,6 @@ import * as Collapsible from '$lib/components/ui/collapsible/index.js'; import { buttonVariants } from '$lib/components/ui/button/index.js'; import { Card } from '$lib/components/ui/card'; - import { MarkdownContent } from '$lib/components/app'; import { config } from '$lib/stores/settings.svelte'; interface Props { @@ -59,7 +58,9 @@
-				<MarkdownContent content={reasoningContent ?? ''} />
+				<div>
+					{reasoningContent ?? ''}
+				</div>
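Note (editor's addition, not part of the patch): the context-shift hunk above first clamps n_keep, then discards half of the remaining tokens and compacts the cached sequence. A minimal standalone C++ sketch of that arithmetic is shown below; the values and variable names are made up for illustration and stand in for the server's slot/task fields.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Illustrative sketch of the context-shift bookkeeping (assumption: plain int
// token ids and hard-coded sizes; the real server operates on server_tokens
// and also shifts the KV cache via the llama_memory_* calls seen in the diff).
int main() {
    std::vector<int> cache(300);
    for (int i = 0; i < 300; ++i) cache[i] = i;   // pretend cached token ids

    const int  n_ctx   = 256;                     // slot context size
    const int  n_past  = (int) cache.size();      // current position
    const bool add_bos = true;

    int n_keep = 64;                              // task-provided n_keep (>= 0 here)
    if (add_bos) n_keep += 1;                     // account for the BOS token
    n_keep = std::min(n_ctx - 4, n_keep);         // never keep (almost) the whole context

    const int n_left    = n_past - n_keep;
    const int n_discard = n_left / 2;             // default when params.n_discard == 0

    // compact: drop [n_keep, n_keep + n_discard) and shift the tail back
    for (size_t i = n_keep + n_discard; i < cache.size(); ++i) {
        cache[i - n_discard] = cache[i];
    }
    cache.resize(cache.size() - n_discard);

    printf("n_keep = %d, n_left = %d, n_discard = %d, new cache size = %zu\n",
           n_keep, n_left, n_discard, cache.size());
    return 0;
}
```

Clamping n_keep to n_ctx - 4 means a full context always yields a positive n_discard, so each shift frees room for further generation instead of stalling.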
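Similarly, the cache_prompt path above reuses the longest common prefix between slot.prompt.tokens and the new request's tokens. The sketch below shows only the text-token case of that comparison; it assumes plain integer ids, whereas the real server_tokens::get_common_prefix additionally matches multimodal chunks by id and position.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

using toks = std::vector<int>; // stand-in for llama_tokens

// length of the longest common prefix between the cached tokens and the new prompt
static size_t common_prefix(const toks & cached, const toks & input) {
    const size_t max_idx = std::min(cached.size(), input.size());
    for (size_t i = 0; i < max_idx; ++i) {
        if (cached[i] != input[i]) {
            return i;
        }
    }
    return max_idx; // one sequence is a prefix of the other
}

int main() {
    const toks cached = {1, 15, 7, 9, 42, 3};     // tokens already in the slot
    const toks input  = {1, 15, 7, 8, 42, 3, 5};  // tokens of the new request

    // the first 3 tokens are reused, so n_past starts at 3 and only the rest
    // of the prompt needs to be evaluated
    printf("reusable prefix: %zu tokens\n", common_prefix(cached, input));
    return 0;
}
```

The [TAG_PROMPT_LOGITS] hunk then decrements n_past by one whenever the whole prompt matches, so that at least one token is evaluated and logits are available for sampling.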