diff --git a/CODEOWNERS b/CODEOWNERS index 53d2e1e7ed49e..bacc86cbbd6d2 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -65,7 +65,7 @@ /ggml/src/ggml-impl.h @ggerganov @slaren /ggml/src/ggml-metal/ @ggerganov /ggml/src/ggml-opencl/ @lhez @max-krasnyansky -/ggml/src/ggml-hexagon/ @max-krasnyansky +/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez /ggml/src/ggml-opt.cpp @JohannesGaessler /ggml/src/ggml-quants.* @ggerganov /ggml/src/ggml-rpc/ @rgerganov diff --git a/common/arg.cpp b/common/arg.cpp index a25743c899862..a465eb36234e7 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3248,7 +3248,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); add_opt(common_arg( {"--embd-output-format"}, "FORMAT", - "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix", + "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)", [](common_params & params, const std::string & value) { params.embd_out = value; } diff --git a/common/chat.cpp b/common/chat.cpp index 8587140e1ff0a..63583fb22489d 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -9,8 +9,11 @@ #include #include +#include #include +#include #include +#include #include #include #include @@ -640,6 +643,7 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2"; case COMMON_CHAT_FORMAT_APERTUS: return "Apertus"; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools"; default: throw std::runtime_error("Unknown chat format"); } @@ -986,6 +990,126 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat return data; } + +// Case-insensitive find +static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) { + auto it = std::search( + haystack.begin() + pos, haystack.end(), + needle.begin(), needle.end(), + [](char a, char b) { return std::tolower(a) == std::tolower(b); } + ); + return (it == haystack.end()) ? 
std::string::npos : std::distance(haystack.begin(), it); +} + +static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + const auto is_json_schema_provided = !inputs.json_schema.is_null(); + const auto is_grammar_provided = !inputs.grammar.empty(); + const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty(); + + // the logic requires potentially modifying the messages + auto tweaked_messages = inputs.messages; + + auto replace_json_schema_marker = [](json & messages) -> bool { + static std::string marker1 = "force json schema.\n"; + static std::string marker2 = "force json schema."; + + if (messages.empty() || messages.at(0).at("role") != "system") { + return false; + } + + std::string content = messages.at(0).at("content"); + + for (const auto & marker : {marker1, marker2}) { + const auto pos = ifind_string(content, marker); + if (pos != std::string::npos) { + content.replace(pos, marker.length(), ""); + // inject modified content back into the messages + messages.at(0).at("content") = content; + return true; + } + } + + return false; + }; + + // Lfm2 model does not natively work with json, but can generally understand the tools structure + // + // Example of the pytorch dialog structure: + // <|startoftext|><|im_start|>system + // List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|> + // <|im_start|>user + // What is the current status of candidate ID 12345?<|im_end|> + // <|im_start|>assistant + // <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|> + // <|im_start|>tool + // <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|> + // <|im_start|>assistant + // The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|> + // + // For the llama server compatibility with json tools semantic, + // the client can add "Follow json schema." line into the system message prompt to force the json output. 
+ // + if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) { + // server/utils.hpp prohibits that branch for the custom grammar anyways + throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar"); + } else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) { + LOG_INF("%s: Using tools to build a grammar\n", __func__); + + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + schemas.push_back({ + {"type", "object"}, + {"properties", { + {"name", { + {"type", "string"}, + {"const", function.at("name")}, + }}, + {"arguments", function.at("parameters")}, + }}, + {"required", json::array({"name", "arguments", "id"})}, + }); + }); + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + + builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\""); + }); + // model has no concept of tool selection mode choice, + // if the system prompt rendered correctly it will produce a tool call + // the grammar goes inside the tool call body + data.grammar_lazy = true; + data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}}; + data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; + data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS; + } else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) { + LOG_INF("%s: Using tools without json schema or grammar\n", __func__); + // output those tokens + data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; + } else if (is_json_schema_provided) { + LOG_INF("%s: Using provided json schema to build a grammar\n", __func__); + data.grammar = json_schema_to_grammar(inputs.json_schema); + } else if (is_grammar_provided) { + LOG_INF("%s: Using provided grammar\n", __func__); + data.grammar = inputs.grammar; + } else { + LOG_INF("%s: Using content relying on the template\n", __func__); + } + + data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages); + LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str()); + + return data; +} + static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -2499,6 +2623,71 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) { builder.add_content(builder.consume_rest()); } + +static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> + static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); + static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); + + // Loop through all tool calls + while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { + builder.move_to(res->groups[0].end); + + // Parse JSON array 
format: [{"name": "...", "arguments": {...}}] + auto tool_calls_data = builder.consume_json(); + + // Consume end marker + builder.consume_spaces(); + if (!builder.try_consume_regex(tool_call_end_regex)) { + throw common_chat_msg_partial_exception("Expected <|tool_call_end|>"); + } + + // Process each tool call in the array + if (tool_calls_data.json.is_array()) { + for (const auto & tool_call : tool_calls_data.json) { + if (!tool_call.is_object()) { + throw common_chat_msg_partial_exception("Tool call must be an object"); + } + + if (!tool_call.contains("name")) { + throw common_chat_msg_partial_exception("Tool call missing 'name' field"); + } + + std::string function_name = tool_call.at("name"); + std::string arguments = "{}"; + + if (tool_call.contains("arguments")) { + if (tool_call.at("arguments").is_object()) { + arguments = tool_call.at("arguments").dump(); + } else if (tool_call.at("arguments").is_string()) { + arguments = tool_call.at("arguments"); + } + } + + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + } else { + throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); + } + + // Consume any trailing whitespace after this tool call + builder.consume_spaces(); + } + + // Consume any remaining content after all tool calls + auto remaining = builder.consume_rest(); + if (!string_strip(remaining).empty()) { + builder.add_content(remaining); + } +} + static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { // Parse thinking tags first - this handles the main reasoning content builder.try_parse_reasoning("", ""); @@ -2748,6 +2937,12 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_apertus(tmpl, params); } + // LFM2 (w/ tools) + if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos && + src.find("]<|tool_list_end|>") != std::string::npos) { + return common_chat_params_init_lfm2(tmpl, params); + } + // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. if ((params.tools.is_array() && params.json_schema.is_object())) { @@ -2926,6 +3121,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_APERTUS: common_chat_parse_apertus(builder); break; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: + common_chat_parse_lfm2(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } diff --git a/common/chat.h b/common/chat.h index f7b36ec711df4..50efb0d4e516f 100644 --- a/common/chat.h +++ b/common/chat.h @@ -116,6 +116,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_SEED_OSS, COMMON_CHAT_FORMAT_NEMOTRON_V2, COMMON_CHAT_FORMAT_APERTUS, + COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index dd9b51a9e50fd..478aa1be7b5b8 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -601,7 +601,10 @@ class SchemaConverter { } std::string _resolve_ref(const std::string & ref) { - std::string ref_name = ref.substr(ref.find_last_of('/') + 1); + auto it = ref.find('#'); + std::string ref_fragment = it != std::string::npos ? 
ref.substr(it + 1) : ref; + static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)"); + std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-"); if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) { _refs_being_resolved.insert(ref); json resolved = _refs[ref]; @@ -774,11 +777,24 @@ class SchemaConverter { std::vector tokens = string_split(pointer, "/"); for (size_t i = 1; i < tokens.size(); ++i) { std::string sel = tokens[i]; - if (target.is_null() || !target.contains(sel)) { + if (target.is_object() && target.contains(sel)) { + target = target[sel]; + } else if (target.is_array()) { + size_t sel_index; + try { + sel_index = std::stoul(sel); + } catch (const std::invalid_argument & e) { + sel_index = target.size(); + } + if (sel_index >= target.size()) { + _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump()); + return; + } + target = target[sel_index]; + } else { _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump()); return; } - target = target[sel]; } _refs[ref] = target; } diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 093f2ab467f4d..b759366684396 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2460,18 +2460,21 @@ def set_gguf_parameters(self): ) class LlavaVisionModel(MmprojModel): img_break_tok_id = -1 + use_break_tok = True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.hparams.get("model_type") == "pixtral": # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5) - self.img_break_tok_id = self.get_token_id("[IMG_BREAK]") + if self.use_break_tok: + self.img_break_tok_id = self.get_token_id("[IMG_BREAK]") elif self.is_mistral_format: # hparams is already vision config here so norm_eps is only defined in global_config. 
self.hparams["norm_eps"] = self.global_config.get("norm_eps", None) assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json" - self.img_break_tok_id = self.find_vparam(["image_break_token_id"]) + if self.use_break_tok: + self.img_break_tok_id = self.find_vparam(["image_break_token_id"]) else: raise ValueError(f"Unsupported model type: {self.hparams['model_type']}") logger.info(f"Image break token id: {self.img_break_tok_id}") @@ -3962,6 +3965,10 @@ def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor: return torch.stack([true_row, false_row], dim=0) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if "model.vision_" in name: + # skip multimodal tensors + return [] + if self.is_rerank: is_tied_head = self.is_tied_embeddings and "embed_tokens" in name is_real_head = not self.is_tied_embeddings and "lm_head" in name @@ -9435,6 +9442,21 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", " return super().map_tensor_name(name, try_suffixes) +@ModelBase.register("LightOnOCRForConditionalGeneration") +class LightOnOCRVisionModel(LlavaVisionModel): + is_mistral_format = False + use_break_tok = False + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + name = name.replace("model.vision_encoder.", "vision_tower.") + name = name.replace("model.vision_projection.", "multi_modal_projector.") + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("KimiVLForConditionalGeneration") class KimiVLModel(MmprojModel): def __init__(self, *args, **kwargs): diff --git a/docs/build.md b/docs/build.md index dcbcce7549ad2..b410c710e30d3 100644 --- a/docs/build.md +++ b/docs/build.md @@ -261,10 +261,12 @@ You can download it from your Linux distro's package manager or from here: [ROCm - Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU): ```bash HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \ - cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ + cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ && cmake --build build --config Release -- -j 16 ``` + Note: `GPU_TARGETS` is optional, omitting it will build the code for all GPUs in the current system. + To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system. The rocWMMA library is included by default when installing the ROCm SDK using the `rocm` meta package provided by AMD. Alternatively, if you are not using the meta package, you can install the library using the `rocwmma-dev` or `rocwmma-devel` package, depending on your system's package manager. @@ -282,17 +284,17 @@ You can download it from your Linux distro's package manager or from here: [ROCm ```bash HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -p)" \ HIP_DEVICE_LIB_PATH= \ - cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ + cmake -S . 
-B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ && cmake --build build -- -j 16 ``` - Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU): ```bash set PATH=%HIP_PATH%\bin;%PATH% - cmake -S . -B build -G Ninja -DAMDGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release + cmake -S . -B build -G Ninja -DGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release cmake --build build ``` - Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors) + If necessary, adapt `GPU_TARGETS` to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors) Find your gpu version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`. diff --git a/examples/embedding/README.md b/examples/embedding/README.md index 3dd279d9fc41a..1684f36480d82 100644 --- a/examples/embedding/README.md +++ b/examples/embedding/README.md @@ -38,6 +38,7 @@ The above command will output space-separated float values. | | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$ | 'json' | openai style | | 'json+' | add cosine similarity matrix | +| 'raw' | plain text output | ### --embd-separator $"string"$ | $"string"$ | | diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 388908bc4d70a..9e3ab5905bb37 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -70,6 +70,29 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } } +// plain, pipe-friendly output: one embedding per line +static void print_raw_embeddings(const float * emb, + int n_embd_count, + int n_embd, + const llama_model * model, + enum llama_pooling_type pooling_type, + int embd_normalize) { + const uint32_t n_cls_out = llama_model_n_cls_out(model); + const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK); + const int cols = is_rank ? std::min(n_embd, (int) n_cls_out) : n_embd; + + for (int j = 0; j < n_embd_count; ++j) { + for (int i = 0; i < cols; ++i) { + if (embd_normalize == 0) { + LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : "")); + } else { + LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? 
" " : "")); + } + } + LOG("\n"); + } +} + int main(int argc, char ** argv) { common_params params; @@ -372,6 +395,8 @@ int main(int argc, char ** argv) { } if (notArray) LOG("\n}\n"); + } else if (params.embd_out == "raw") { + print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize); } LOG("\n"); diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 2d57549046b88..26989157fe6b6 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -371,8 +371,17 @@ def visit(n: dict): raise ValueError(f'Unsupported ref {ref}') for sel in ref.split('#')[-1].split('/')[1:]: - assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' - target = target[sel] + assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}' + if isinstance(target, list): + try: + sel_index = int(sel) + except ValueError: + raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}') + assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel_index] + else: + assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel] self._refs[ref] = target else: @@ -547,7 +556,8 @@ def join_seq(): def _resolve_ref(self, ref): - ref_name = ref.split('/')[-1] + ref_fragment = ref.split('#')[-1] + ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment) if ref_name not in self._rules and ref not in self._refs_being_resolved: self._refs_being_resolved.add(ref) resolved = self._refs[ref] diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index f030ea0136a95..5df6dc96a3b2e 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, ACL_MEM_MALLOC_HUGE_FIRST)); acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), - theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + theta_scale_ne, theta_scale_nb, 1); float start = 0; float step = 1; @@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); void * yarn_ramp_buffer = yarn_ramp_allocator.get(); acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, - theta_scale_nb, GGML_MAX_DIMS); + theta_scale_nb, 1); float zero_value = 0, one_value = 1; float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 8bd5449f1f75f..51345742ee59e 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -67,19 +67,30 @@ GGML_ABORT("CANN error"); } +// Thread-local variable to record the current device of this thread. +thread_local int g_current_cann_device = -1; + /** - * @brief Sets the device to be used by CANN. + * @brief Set the CANN device to be used. * - * @param device The device ID to set. + * @param device The target device ID to set. */ void ggml_cann_set_device(const int32_t device) { - int current_device = -1; - aclrtGetDevice(¤t_device); + // int current_device = -1; + // Note: In some CANN versions, if no device has been set yet, + // aclrtGetDevice(¤t_device) may return 0 by default. 
+ // aclrtGetDevice(¤t_device); - if (device == current_device) { + // If the current device is already the target one, no need to switch. + if (device == g_current_cann_device) { return; } + + // Switch to the new device. ACL_CHECK(aclrtSetDevice(device)); + + // Update the global device record. + g_current_cann_device = device; } /** diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index b52f0f8472cfe..3156bd60101d7 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7519,8 +7519,8 @@ static void ggml_compute_forward_upscale_f32( float pixel_offset = 0.5f; if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { pixel_offset = 0.0f; - sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1); - sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1); + sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; + sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1; } for (int64_t i3 = 0; i3 < ne3; i3++) { diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index c2c31cdaf231b..4e31783436d80 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -343,6 +343,10 @@ static __global__ void mul_mat_vec_f( } dst[tid*stride_col_dst + row] = value; + + if constexpr (!has_fusion) { + GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate); + } } template diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 7a783e4fcf9b4..be04a85cc5515 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -310,6 +310,10 @@ static __global__ void mul_mat_vec_q( dst[j*stride_col_dst + threadIdx.x] = result; } } + + if constexpr (!has_fusion) { + GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate); + } } static std::pair calc_launch_params( diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu index ef48aa5f97bcd..35b7e61d80ac9 100644 --- a/ggml/src/ggml-cuda/upscale.cu +++ b/ggml/src/ggml-cuda/upscale.cu @@ -126,8 +126,8 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } else if (mode == GGML_SCALE_MODE_BILINEAR) { float pixel_offset = 0.5f; if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { - sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1); - sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1); + sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0; + sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? 
(float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1; pixel_offset = 0.0f; } upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index ecfc1c856cb59..2d376a6025c07 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -211,12 +211,15 @@ static inline void hex_format_op_names(char * str, const struct ggml_tensor * t) // ** backend sessions struct ggml_hexagon_session { - ggml_hexagon_session(int dev_id) noexcept(false); + ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); ~ggml_hexagon_session() noexcept(true); void allocate(int dev_id) noexcept(false); void release() noexcept(true); + void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false); + void flush(); + ggml_backend_buffer_type buffer_type; ggml_backend_buffer_type repack_buffer_type; @@ -237,15 +240,37 @@ struct ggml_hexagon_session { uint32_t prof_pkts; }; -// Packet callback -static void htp_packet_callback(dspqueue_t queue, AEEResult error, void * context) { - auto sess = static_cast(context); +void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) { + // Bump pending flag (cleared in the session::flush once we get the response) + this->op_pending++; // atomic inc + + int err = dspqueue_write(this->queue, + 0, // flags - the framework will autoset this + n_bufs, // number of buffers + bufs, // buffer references + sizeof(req), + (const uint8_t *) &req, // Message + 1000000 // Timeout + ); + + if (err != 0) { + GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->name.c_str(), (unsigned) err); + } + + if (sync) { + flush(); + } +} + +// Flush HTP response queue, i.e. wait for all outstanding requests to complete +void ggml_hexagon_session::flush() { + dspqueue_t q = this->queue; // Repeatedly read packets from the queue until it's empty. We don't // necessarily get a separate callback for each packet, and new packets // may arrive while we're processing the previous one.
- while (1) { + while (this->op_pending) { struct htp_general_rsp rsp; uint32_t rsp_size; uint32_t flags; @@ -253,22 +278,23 @@ static void htp_packet_callback(dspqueue_t queue, AEEResult error, void * contex struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; uint32_t n_bufs; - // Read packet from queue - int err = dspqueue_read_noblock(queue, &flags, - HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references - &n_bufs, // Number of buffer references - bufs, // Buffer references - sizeof(rsp), // Max message length - &rsp_size, // Message length - (uint8_t *) &rsp); - - if (err == AEE_EWOULDBLOCK) { - // Consumed all packets available for now - return; + // Read response packet from queue + int err = dspqueue_read(q, &flags, + HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references + &n_bufs, // Number of buffer references + bufs, // Buffer references + sizeof(rsp), // Max message length + &rsp_size, // Message length + (uint8_t *) &rsp, + 1000000); // Timeout + + if (err == AEE_EEXPIRED) { + // TODO: might need to bail out if the HTP is stuck on something + continue; } if (err != 0) { - GGML_ABORT("ggml-hex: dspqueue_read_noblock failed: 0x%08x\n", (unsigned) err); + GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err); } // Basic sanity checks @@ -281,21 +307,15 @@ static void htp_packet_callback(dspqueue_t queue, AEEResult error, void * contex // TODO: handle errors } - // FIXME: update profiling implementation - sess->prof_usecs = rsp.prof_usecs; - sess->prof_cycles = rsp.prof_cycles; - sess->prof_pkts = rsp.prof_pkts; + // TODO: update profiling implementation, currently only works for opt_opsync mode + this->prof_usecs = rsp.prof_usecs; + this->prof_cycles = rsp.prof_cycles; + this->prof_pkts = rsp.prof_pkts; - sess->op_pending--; // atomic dec + this->op_pending--; // atomic dec } } -// Error callback - simply terminates with an error. Used where we don't -// expect errors. -[[noreturn]] static void htp_error_callback(dspqueue_t queue, AEEResult error, void * context) { - GGML_ABORT("ggml-hex: dspcall general error 0x%x: for queue %p\n", error, (void *) queue); -} - // ** backend buffers struct ggml_backend_hexagon_buffer_type_context { @@ -1564,7 +1584,8 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { 0, // Flags 128 * 1024, // Request queue size (in bytes) 64 * 1024, // Response queue size (in bytes) - htp_packet_callback, htp_error_callback, + nullptr, // Read packet callback (we handle reads explicitly) + nullptr, // Error callback (we handle errors during reads) (void *) this, // Callback context &queue); if (err != 0) { @@ -1631,10 +1652,13 @@ void ggml_hexagon_session::release() noexcept(true) { } } -ggml_hexagon_session::ggml_hexagon_session(int dev_id) noexcept(false) { +ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) { buffer_type.context = nullptr; repack_buffer_type.context = nullptr; + buffer_type.device = dev; + repack_buffer_type.device = dev; + try { allocate(dev_id); @@ -2202,7 +2226,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags) bufs[0].ptr = src0->data; bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; bufs[0].size = ggml_nbytes(src0); - bufs[0].flags = DSPQUEUE_BUFFER_FLAG_REF; + bufs[0].flags = 0; // Second buffer Input Activations. 
This is a buffer that the CPU // writes and the DSP reads, so we'll need to flush CPU caches and @@ -2212,8 +2236,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags) bufs[1].ptr = src1->data; bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; bufs[1].size = ggml_nbytes(src1); - bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP // Third buffer Output Activations. We'll handle DSP @@ -2224,7 +2247,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags) bufs[2].ptr = dst->data; bufs[2].offset = (uint8_t *) dst->data - dst_buf->base; bufs[2].size = ggml_nbytes(dst); - bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); // Primary DSP session from the src0 (normally weight) tensor auto sess = src0_buf->sess; @@ -2252,27 +2275,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags) } if ((opt_opmask & HTP_OPMASK_QUEUE)) { - // Bump pending flag (cleared in the callback once we get the responce) - sess->op_pending++; // atomic inc - - int err = dspqueue_write(sess->queue, - 0, // flags - the framework will autoset this - 3, // number of buffers - bufs, // buffer references - sizeof(req), - (const uint8_t *) &req, // Message - 1000000 // Timeout - ); - - if (err != 0) { - GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err); - } - } - - if (opt_opsync) { - while (sess->op_pending) { - ; - } + sess->enqueue(req, bufs, 3, opt_opsync); } t2 = ggml_time_us(); @@ -2328,7 +2331,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag bufs[0].ptr = src0->data; bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; bufs[0].size = ggml_nbytes(src0); - bufs[0].flags = DSPQUEUE_BUFFER_FLAG_REF; + bufs[0].flags = 0; // Second buffer Input Activations. This is a buffer that the CPU // writes and the DSP reads, so we'll need to flush CPU caches and @@ -2338,8 +2341,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag bufs[1].ptr = src1->data; bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; bufs[1].size = ggml_nbytes(src1); - bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP // Third buffer expert IDs. This is a buffer that the CPU @@ -2350,8 +2352,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag bufs[2].ptr = src2->data; bufs[2].offset = (uint8_t *) src2->data - src2_buf->base; bufs[2].size = ggml_nbytes(src2); - bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP // Forth buffer Output Activations. 
We'll handle DSP @@ -2362,7 +2363,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag bufs[3].ptr = dst->data; bufs[3].offset = (uint8_t *) dst->data - dst_buf->base; bufs[3].size = ggml_nbytes(dst); - bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); // Primary DSP session from the src0 (normally weight) tensor auto sess = src0_buf->sess; @@ -2391,27 +2392,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag } if ((opt_opmask & HTP_OPMASK_QUEUE)) { - // Bump pending flag (cleared in the callback once we get the responce) - sess->op_pending++; // atomic inc - - int err = dspqueue_write(sess->queue, - 0, // flags - the framework will autoset this - 4, // number of buffers - bufs, // buffer references - sizeof(req), - (const uint8_t *) &req, // Message - 1000000 // Timeout - ); - - if (err != 0) { - GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err); - } - } - - if (opt_opsync) { - while (sess->op_pending) { - ; - } + sess->enqueue(req, bufs, 4, opt_opsync); } t2 = ggml_time_us(); @@ -2484,8 +2465,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) { bufs[0].ptr = src0->data; bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; bufs[0].size = ggml_nbytes(src0); - bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; // Second buffer = Second Operand of Binary op @@ -2497,8 +2477,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) { bufs[1].ptr = src1->data; bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; bufs[1].size = ggml_nbytes(src1); - bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP // Third buffer = Output Activations. 
We'll handle DSP @@ -2509,7 +2488,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) { bufs[2].ptr = dst->data; bufs[2].offset = (uint8_t *) dst->data - dst_buf->base; bufs[2].size = ggml_nbytes(dst); - bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); // Primary DSP session from the src0 tensor ggml_hexagon_session * sess = src0_buf->sess; @@ -2537,26 +2516,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) { } if ((opt_opmask & HTP_OPMASK_QUEUE)) { - // Bump pending flag (cleared in the callback once we get the responce) - sess->op_pending++; // atomic inc - - int err = dspqueue_write(sess->queue, - 0, // flags - the framework will autoset this - 3, // number of buffers - bufs, // buffer references - sizeof(req), - (const uint8_t *) &req, // Message - 1000000); // Timeout - - if (0 != err) { - GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err); - } - } - - if (opt_opsync) { - while (sess->op_pending) { - ; - } + sess->enqueue(req, bufs, 3, opt_opsync); } t2 = ggml_time_us(); @@ -2621,8 +2581,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) { bufs[0].ptr = src0->data; bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; bufs[0].size = ggml_nbytes(src0); - bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; // Second buffer = experts bias @@ -2630,8 +2589,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) { bufs[1].ptr = src1->data; bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; bufs[1].size = ggml_nbytes(src1); - bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP // Third buffer = activated experts @@ -2639,8 +2597,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) { bufs[2].ptr = src2->data; bufs[2].offset = (uint8_t *) src2->data - src2_buf->base; bufs[2].size = ggml_nbytes(src2); - bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP // Forth buffer = output activations @@ -2648,7 +2605,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) { bufs[3].ptr = dst->data; bufs[3].offset = (uint8_t *) dst->data - dst_buf->base; bufs[3].size = ggml_nbytes(dst); - bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); // Primary DSP session from the src0 tensor ggml_hexagon_session * sess = src0_buf->sess; @@ -2678,26 +2635,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) { } if ((opt_opmask & HTP_OPMASK_QUEUE)) { - // Bump pending flag (cleared in the callback once we get the responce) - sess->op_pending++; // atomic inc - - int err = dspqueue_write(sess->queue, - 0, // flags - the framework will autoset this - 4, // number of buffers - bufs, // buffer references - 
sizeof(req), - (const uint8_t *) &req, // Message - 1000000); // Timeout - - if (0 != err) { - GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err); - } - } - - if (opt_opsync) { - while (sess->op_pending) { - ; - } + sess->enqueue(req, bufs, 4, opt_opsync); } t2 = ggml_time_us(); @@ -2795,8 +2733,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) { bufs[n_bufs].ptr = src0->data; bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base; bufs[n_bufs].size = ggml_nbytes(src0); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; ++n_bufs; @@ -2811,8 +2748,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) { bufs[n_bufs].ptr = src1->data; bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base; bufs[n_bufs].size = ggml_nbytes(src1); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP ++n_bufs; } @@ -2827,7 +2763,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) { bufs[n_bufs].ptr = dst->data; bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base; bufs[n_bufs].size = ggml_nbytes(dst); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); ++n_bufs; // Primary DSP session from the src0 tensor @@ -2860,26 +2796,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) { } if ((opt_opmask & HTP_OPMASK_QUEUE)) { - // Bump pending flag (cleared in the callback once we get the responce) - sess->op_pending++; // atomic inc - - int err = dspqueue_write(sess->queue, - 0, // flags - the framework will autoset this - n_bufs, // number of buffers - bufs, // buffer references - sizeof(req), - (const uint8_t *) &req, // Message - 1000000); // Timeout - - if (0 != err) { - GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err); - } - } - - if (opt_opsync) { - while (sess->op_pending) { - ; - } + sess->enqueue(req, bufs, n_bufs, opt_opsync); } t2 = ggml_time_us(); @@ -2953,8 +2870,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) { bufs[n_bufs].ptr = src0->data; bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base; bufs[n_bufs].size = ggml_nbytes(src0); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; ++n_bufs; @@ -2968,8 +2884,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) { bufs[n_bufs].ptr = src1->data; bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base; bufs[n_bufs].size = ggml_nbytes(src1); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP ++n_bufs; @@ -2984,8 +2899,7 @@ static void 
ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) { bufs[n_bufs].ptr = src2->data; bufs[n_bufs].offset = (uint8_t *) src2->data - src2_buf->base; bufs[n_bufs].size = ggml_nbytes(src2); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP ++n_bufs; } @@ -3000,7 +2914,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) { bufs[n_bufs].ptr = dst->data; bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base; bufs[n_bufs].size = ggml_nbytes(dst); - bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); ++n_bufs; // Primary DSP session from the src0 tensor @@ -3033,26 +2947,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) { } if ((opt_opmask & HTP_OPMASK_QUEUE)) { - // Bump pending flag (cleared in the callback once we get the responce) - sess->op_pending++; // atomic inc - - int err = dspqueue_write(sess->queue, - 0, // flags - the framework will autoset this - n_bufs, // number of buffers - bufs, // buffer references - sizeof(req), - (const uint8_t *) &req, // Message - 1000000); // Timeout - - if (0 != err) { - GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err); - } - } - - if (opt_opsync) { - while (sess->op_pending) { - ; - } + sess->enqueue(req, bufs, n_bufs, opt_opsync); } t2 = ggml_time_us(); @@ -3197,9 +3092,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg } // Wait until all pending ops complete - while (sess->op_pending) { - ; - } + sess->flush(); return GGML_STATUS_SUCCESS; } @@ -3210,9 +3103,7 @@ static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) { HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str()); // Wait until all pending ops complete - while (sess->op_pending) { - ; - } + sess->flush(); } struct node_info { @@ -3628,7 +3519,7 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) { devices[i].iface = ggml_backend_hexagon_device_i; devices[i].reg = reg; try { - devices[i].context = new ggml_hexagon_session(i); + devices[i].context = new ggml_hexagon_session(i, &devices[i]); } catch (std::exception const &exc) { GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i); devices[i].context = nullptr; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index e35ea3b0211c8..10e2733324354 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -395,28 +395,14 @@ static void proc_matmul_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, size_t n_bufs) { - // Prep response buffer structs (needed for error responses, etc) - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - memset(rsp_bufs, 0, sizeof(rsp_bufs)); - rsp_bufs[0].fd = bufs[0].fd; - rsp_bufs[0].ptr = bufs[0].ptr; - rsp_bufs[0].size = bufs[0].size; - rsp_bufs[0].offset = bufs[0].offset; - rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference - - rsp_bufs[1].fd = bufs[1].fd; - rsp_bufs[1].ptr = bufs[1].ptr; - rsp_bufs[1].size = bufs[1].size; - rsp_bufs[1].offset = bufs[1].offset; - rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference + struct dspqueue_buffer rsp_bufs[1]; // We had 
written to the output buffer, we'd also need to flush it - rsp_bufs[2].fd = bufs[2].fd; - rsp_bufs[2].ptr = bufs[2].ptr; - rsp_bufs[2].size = bufs[2].size; - rsp_bufs[2].offset = bufs[2].offset; - rsp_bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU // Setup Op context @@ -444,41 +430,21 @@ static void proc_matmul_req(struct htp_context * ctx, } profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } static void proc_matmul_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, size_t n_bufs) { - // Prep response buffer structs (needed for error responses, etc) - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - memset(rsp_bufs, 0, sizeof(rsp_bufs)); - rsp_bufs[0].fd = bufs[0].fd; - rsp_bufs[0].ptr = bufs[0].ptr; - rsp_bufs[0].size = bufs[0].size; - rsp_bufs[0].offset = bufs[0].offset; - rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference - - rsp_bufs[1].fd = bufs[1].fd; - rsp_bufs[1].ptr = bufs[1].ptr; - rsp_bufs[1].size = bufs[1].size; - rsp_bufs[1].offset = bufs[1].offset; - rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference - - rsp_bufs[2].fd = bufs[2].fd; - rsp_bufs[2].ptr = bufs[2].ptr; - rsp_bufs[2].size = bufs[2].size; - rsp_bufs[2].offset = bufs[2].offset; - rsp_bufs[2].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference + struct dspqueue_buffer rsp_bufs[1]; // We had written to the output buffer, we'd also need to flush it - rsp_bufs[3].fd = bufs[3].fd; - rsp_bufs[3].ptr = bufs[3].ptr; - rsp_bufs[3].size = bufs[3].size; - rsp_bufs[3].offset = bufs[3].offset; - rsp_bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP + rsp_bufs[0].fd = bufs[3].fd; + rsp_bufs[0].ptr = bufs[3].ptr; + rsp_bufs[0].size = bufs[3].size; + rsp_bufs[0].offset = bufs[3].offset; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU // Setup Op context @@ -508,32 +474,18 @@ static void proc_matmul_id_req(struct htp_context * ctx, } profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - memset(rsp_bufs, 0, sizeof(rsp_bufs)); - - rsp_bufs[0].fd = bufs[0].fd; - rsp_bufs[0].ptr = bufs[0].ptr; - rsp_bufs[0].offset = bufs[0].offset; - rsp_bufs[0].size = bufs[0].size; - rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference - - rsp_bufs[1].fd = bufs[1].fd; - rsp_bufs[1].ptr = bufs[1].ptr; - rsp_bufs[1].offset = bufs[1].offset; - rsp_bufs[1].size = bufs[1].size; - rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference + struct dspqueue_buffer rsp_bufs[1]; // We had written to the output buffer, we'd also need to flush it - rsp_bufs[2].fd = bufs[2].fd; - rsp_bufs[2].ptr = bufs[2].ptr; - rsp_bufs[2].offset = bufs[2].offset; - rsp_bufs[2].size = bufs[2].size; - 
rsp_bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU // Setup Op context @@ -561,38 +513,18 @@ static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * r } profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - memset(rsp_bufs, 0, sizeof(rsp_bufs)); - - rsp_bufs[0].fd = bufs[0].fd; - rsp_bufs[0].ptr = bufs[0].ptr; - rsp_bufs[0].offset = bufs[0].offset; - rsp_bufs[0].size = bufs[0].size; - rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference - - rsp_bufs[1].fd = bufs[1].fd; - rsp_bufs[1].ptr = bufs[1].ptr; - rsp_bufs[1].offset = bufs[1].offset; - rsp_bufs[1].size = bufs[1].size; - rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference - - rsp_bufs[2].fd = bufs[2].fd; - rsp_bufs[2].ptr = bufs[2].ptr; - rsp_bufs[2].offset = bufs[2].offset; - rsp_bufs[2].size = bufs[2].size; - rsp_bufs[2].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference + struct dspqueue_buffer rsp_bufs[1]; // We had written to the output buffer, we'd also need to flush it - rsp_bufs[3].fd = bufs[3].fd; - rsp_bufs[3].ptr = bufs[3].ptr; - rsp_bufs[3].offset = bufs[3].offset; - rsp_bufs[3].size = bufs[3].size; - rsp_bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP + rsp_bufs[0].fd = bufs[3].fd; + rsp_bufs[0].ptr = bufs[3].ptr; + rsp_bufs[0].offset = bufs[3].offset; + rsp_bufs[0].size = bufs[3].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU // Setup Op context @@ -622,26 +554,18 @@ static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * r } profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - memset(rsp_bufs, 0, sizeof(rsp_bufs)); - - rsp_bufs[0].fd = bufs[0].fd; - rsp_bufs[0].ptr = bufs[0].ptr; - rsp_bufs[0].offset = bufs[0].offset; - rsp_bufs[0].size = bufs[0].size; - rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference // We had written to the output buffer, we'd also need to flush it - rsp_bufs[1].fd = bufs[1].fd; - rsp_bufs[1].ptr = bufs[1].ptr; - rsp_bufs[1].offset = bufs[1].offset; - rsp_bufs[1].size = bufs[1].size; - rsp_bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU // Setup Op context @@ -669,7 +593,7 @@ static void proc_unary_req(struct htp_context * ctx, struct 
htp_general_req * re } profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 2, &prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } static void proc_activations_req(struct htp_context * ctx, @@ -677,33 +601,16 @@ static void proc_activations_req(struct htp_context * ctx, struct dspqueue_buffer * bufs, uint32_t n_bufs) { struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - memset(rsp_bufs, 0, sizeof(rsp_bufs)); - - rsp_bufs[0].fd = bufs[0].fd; - rsp_bufs[0].ptr = bufs[0].ptr; - rsp_bufs[0].offset = bufs[0].offset; - rsp_bufs[0].size = bufs[0].size; - rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference - int write_idx = 1; - if (3 == n_bufs) { - rsp_bufs[1].fd = bufs[1].fd; - rsp_bufs[1].ptr = bufs[1].ptr; - rsp_bufs[1].offset = bufs[1].offset; - rsp_bufs[1].size = bufs[1].size; - rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference - - write_idx = 2; - } + int write_idx = (n_bufs == 3) ? 2 : 1; // We had written to the output buffer, we'd also need to flush it - rsp_bufs[write_idx].fd = bufs[write_idx].fd; - rsp_bufs[write_idx].ptr = bufs[write_idx].ptr; - rsp_bufs[write_idx].offset = bufs[write_idx].offset; - rsp_bufs[write_idx].size = bufs[write_idx].size; - rsp_bufs[write_idx].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + rsp_bufs[0].fd = bufs[write_idx].fd; + rsp_bufs[0].ptr = bufs[write_idx].ptr; + rsp_bufs[0].offset = bufs[write_idx].offset; + rsp_bufs[0].size = bufs[write_idx].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU // Setup Op context struct htp_ops_context octx = { 0 }; @@ -742,7 +649,7 @@ static void proc_activations_req(struct htp_context * ctx, } profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } static void proc_rope_req(struct htp_context * ctx, @@ -750,39 +657,16 @@ static void proc_rope_req(struct htp_context * ctx, struct dspqueue_buffer * bufs, uint32_t n_bufs) { struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - memset(rsp_bufs, 0, sizeof(rsp_bufs)); - - rsp_bufs[0].fd = bufs[0].fd; - rsp_bufs[0].ptr = bufs[0].ptr; - rsp_bufs[0].offset = bufs[0].offset; - rsp_bufs[0].size = bufs[0].size; - rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference - rsp_bufs[1].fd = bufs[1].fd; - rsp_bufs[1].ptr = bufs[1].ptr; - rsp_bufs[1].offset = bufs[1].offset; - rsp_bufs[1].size = bufs[1].size; - rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference - - int write_idx = 2; - if (4 == n_bufs) { - rsp_bufs[write_idx].fd = bufs[write_idx].fd; - rsp_bufs[write_idx].ptr = bufs[write_idx].ptr; - rsp_bufs[write_idx].offset = bufs[write_idx].offset; - rsp_bufs[write_idx].size = bufs[write_idx].size; - rsp_bufs[write_idx].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference - - write_idx++; - } + int write_idx = (n_bufs == 4) ? 
3 : 2; // We had written to the output buffer, we'd also need to flush it - rsp_bufs[write_idx].fd = bufs[write_idx].fd; - rsp_bufs[write_idx].ptr = bufs[write_idx].ptr; - rsp_bufs[write_idx].offset = bufs[write_idx].offset; - rsp_bufs[write_idx].size = bufs[write_idx].size; - rsp_bufs[write_idx].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference - DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + rsp_bufs[0].fd = bufs[write_idx].fd; + rsp_bufs[0].ptr = bufs[write_idx].ptr; + rsp_bufs[0].offset = bufs[write_idx].offset; + rsp_bufs[0].size = bufs[write_idx].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU // Setup Op context struct htp_ops_context octx = { 0 }; @@ -819,7 +703,7 @@ static void proc_rope_req(struct htp_context * ctx, } profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } static void htp_packet_callback(dspqueue_t queue, int error, void * context) { diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 6b499320e7b12..23b6889919f20 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -29,10 +29,11 @@ if (CXX_IS_HIPCC) endif() else() # Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES. + if(AMDGPU_TARGETS AND NOT GPU_TARGETS) + set(GPU_TARGETS ${AMDGPU_TARGETS}) + endif() if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS}) - elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) - set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) endif() cmake_minimum_required(VERSION 3.21) enable_language(HIP) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index db33a4ab6c2e3..93a3600b63f07 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -6156,8 +6156,8 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3)); } else if (mode == GGML_SCALE_MODE_BILINEAR) { if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { - sf0 = (float)(ne0 - 1) / (ne00 - 1); - sf1 = (float)(ne1 - 1) / (ne01 - 1); + sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; + sf1 = ne1 > 1 && ne01 > 1 ? 
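The Hexagon HTP request handlers above (binary, add-id, unary, activations, rope) now all return a single response buffer: the output the DSP just wrote, flagged so the HTP flushes its cache and the CPU invalidates its copy, which is why every send_htp_rsp() call now passes a buffer count of 1. A minimal sketch of that shared setup, using stand-in definitions for the SDK's dspqueue_buffer struct and flag macros referenced in the diff:

#include <stdint.h>
#include <string.h>

// Stand-ins for the Hexagon SDK definitions used above (dspqueue_buffer and
// the DSPQUEUE_BUFFER_FLAG_* values); the field layout mirrors the diff.
struct dspqueue_buffer_sketch {
    int      fd;
    void    *ptr;
    uint32_t offset;
    uint32_t size;
    uint32_t flags;
};

enum {
    FLAG_FLUSH_SENDER         = 1u << 0,  // stand-in for DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER
    FLAG_INVALIDATE_RECIPIENT = 1u << 1,  // stand-in for DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT
};

// Build the single response entry: the buffer the DSP wrote, flagged so the
// sender (HTP) flushes it and the recipient (CPU) invalidates it.
static void set_rsp_output_buf(struct dspqueue_buffer_sketch       *rsp,
                               const struct dspqueue_buffer_sketch *out) {
    memset(rsp, 0, sizeof(*rsp));
    rsp->fd     = out->fd;
    rsp->ptr    = out->ptr;
    rsp->offset = out->offset;
    rsp->size   = out->size;
    rsp->flags  = FLAG_FLUSH_SENDER | FLAG_INVALIDATE_RECIPIENT;
}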
(float)(ne1 - 1) / (ne01 - 1) : sf1; pixel_offset = 0.0f; } diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index ca53f3e90068c..75657f3fca2e7 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -35,6 +35,7 @@ #include "roll.hpp" #include "rope.hpp" #include "set_rows.hpp" +#include "ssm_conv.hpp" #include "softmax.hpp" #include "tsembd.hpp" #include "wkv.hpp" diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 62d0ecd94ee0a..328d1a71b7580 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -50,6 +50,7 @@ #include "ggml-sycl/getrows.hpp" #include "ggml-sycl/repeat_back.hpp" #include "ggml-sycl/quantize.hpp" +#include "ggml-sycl/ssm_conv.hpp" #include "ggml.h" static bool g_sycl_loaded = false; @@ -3921,6 +3922,8 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_sycl_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_SSM_CONV: + ggml_sycl_ssm_conv(ctx, dst); case GGML_OP_ROLL: ggml_sycl_roll(ctx, dst); break; @@ -4602,6 +4605,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_RWKV_WKV7: case GGML_OP_GATED_LINEAR_ATTN: return true; + case GGML_OP_SSM_CONV: + return op->type == GGML_TYPE_F32 && + op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32; case GGML_OP_ROLL: return op->type == GGML_TYPE_F32; case GGML_OP_ARANGE: diff --git a/ggml/src/ggml-sycl/ssm_conv.cpp b/ggml/src/ggml-sycl/ssm_conv.cpp new file mode 100644 index 0000000000000..0dc0f71c9a157 --- /dev/null +++ b/ggml/src/ggml-sycl/ssm_conv.cpp @@ -0,0 +1,127 @@ +#include "ssm_conv.hpp" +#include "common.hpp" + +#include + +using namespace sycl; + +static void kernel_ssm_conv( + queue &q, + const float *src_data, + const float *weights, + float *dst_data, + int d_conv, + int d_inner, + int n_t, + int n_s, + int ncs __attribute__((unused)), + int src_stride_inner, + int src_stride_seq, + int dst_stride_token, + int dst_stride_seq +) { + const size_t total_work = static_cast(d_inner) * static_cast(n_t) * static_cast(n_s); + const size_t work_group_size = 256; + const size_t num_work_groups = (total_work + work_group_size - 1) / work_group_size; + + const range<1> global_range(num_work_groups * work_group_size); + const range<1> local_range(work_group_size); + + q.submit([&](handler &h) { + h.parallel_for( + nd_range<1>(global_range, local_range), + [=](nd_item<1> item) { + const size_t idx = item.get_global_id(0); + if (idx >= total_work) { + return; + } + + const int channel = static_cast(idx % d_inner); + const int token = static_cast((idx / d_inner) % n_t); + const int seq = static_cast(idx / (static_cast(d_inner) * static_cast(n_t))); + + const float *s = src_data + + static_cast(seq) * static_cast(src_stride_seq) + + static_cast(channel) * static_cast(src_stride_inner) + + static_cast(token); + + const float *c = weights + static_cast(channel) * static_cast(d_conv); + + float sumf = 0.0f; + for (int i0 = 0; i0 < d_conv; ++i0) { + sumf += s[i0] * c[i0]; + } + + const size_t dst_idx = + static_cast(seq) * static_cast(dst_stride_seq) + + static_cast(token) * static_cast(dst_stride_token) + + static_cast(channel); + + dst_data[dst_idx] = sumf; + } + ); + }); +} + +void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + 
GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int d_conv = src1->ne[0]; + const int ncs = src0->ne[0]; + const int d_inner = src0->ne[1]; + const int n_t = dst->ne[1]; + const int n_s = dst->ne[2]; + + GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t); + GGML_ASSERT(src0->ne[1] == d_inner); + GGML_ASSERT(src1->ne[1] == d_inner); + + GGML_ASSERT(dst->ne[0] == d_inner); + GGML_ASSERT(dst->ne[1] == n_t); + GGML_ASSERT(dst->ne[2] == n_s); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + + GGML_ASSERT(src0->nb[1] == src0->ne[0] * static_cast(sizeof(float))); + + const int src_stride_inner = ncs; + const int src_stride_seq = ncs * d_inner; + const int dst_stride_token = d_inner; + const int dst_stride_seq = d_inner * n_t; + + try { + queue *q = ctx.stream(); + + const float *src_data = static_cast(src0->data); + const float *weights = static_cast(src1->data); + float *dst_data = static_cast(dst->data); + + GGML_ASSERT(src_data && weights && dst_data); + + kernel_ssm_conv( + *q, + src_data, + weights, + dst_data, + d_conv, + d_inner, + n_t, + n_s, + ncs, + src_stride_inner, + src_stride_seq, + dst_stride_token, + dst_stride_seq + ); + + } catch (const std::exception &e) { + std::fprintf(stderr, "[SYCL-SSM_CONV] ERROR: %s\n", e.what()); + throw; + } +} diff --git a/ggml/src/ggml-sycl/ssm_conv.hpp b/ggml/src/ggml-sycl/ssm_conv.hpp new file mode 100644 index 0000000000000..1a8ad05f0c7f0 --- /dev/null +++ b/ggml/src/ggml-sycl/ssm_conv.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b783f7805e924..173677a2637a9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -523,7 +523,7 @@ struct vk_device_struct { vk_pipeline pipeline_add_id_f32; vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; - vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32; + vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32; vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_sqr_f32; vk_pipeline pipeline_sqrt_f32; @@ -1238,6 +1238,7 @@ struct vk_op_upscale_push_constants { uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; float sf0; float sf1; float sf2; float sf3; + float pixel_offset; }; struct vk_op_sum_rows_push_constants @@ -3493,7 +3494,6 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1); ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1); - ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1); ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -7798,14 
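The new SYCL SSM_CONV kernel above computes, for every (sequence, token, channel) triple, a dot product between a d_conv-wide window of the source row and that channel's weights, writing the result at [channel, token, seq] in the destination. A plain CPU reference of the same indexing (a sketch for clarity, mirroring the strides passed to kernel_ssm_conv; it is not the backend code):

#include <cstddef>

// CPU reference for the SSM_CONV indexing used by the SYCL kernel above.
// Layouts (fastest dimension first, as in ggml):
//   src     : [ncs, d_inner, n_s]  with ncs = d_conv - 1 + n_t
//   weights : [d_conv, d_inner]
//   dst     : [d_inner, n_t, n_s]
static void ssm_conv_ref(const float *src, const float *weights, float *dst,
                         int d_conv, int d_inner, int n_t, int n_s) {
    const int ncs = d_conv - 1 + n_t;
    for (int seq = 0; seq < n_s; ++seq) {
        for (int tok = 0; tok < n_t; ++tok) {
            for (int ch = 0; ch < d_inner; ++ch) {
                // window of d_conv consecutive samples for this channel, starting at the token
                const float *s = src + (size_t) seq * ncs * d_inner + (size_t) ch * ncs + tok;
                const float *c = weights + (size_t) ch * d_conv;

                float sum = 0.0f;
                for (int i = 0; i < d_conv; ++i) {
                    sum += s[i] * c[i];
                }

                dst[(size_t) seq * d_inner * n_t + (size_t) tok * d_inner + ch] = sum;
            }
        }
    }
}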
+7798,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_UPSCALE: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - int mode = ggml_get_op_params_i32(dst, 0); + ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(dst, 0) & 0xFF); switch (mode) { case GGML_SCALE_MODE_NEAREST: return ctx->device->pipeline_upscale_nearest_f32; case GGML_SCALE_MODE_BILINEAR: return ctx->device->pipeline_upscale_bilinear_f32; - case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS: - return ctx->device->pipeline_upscale_bilinear_ac_f32; + default: + return nullptr; } } return nullptr; @@ -9294,22 +9294,26 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0); - float sf0 = (float)dst->ne[0] / src0->ne[0]; - float sf1 = (float)dst->ne[1] / src0->ne[1]; - float sf2 = (float)dst->ne[2] / src0->ne[2]; - float sf3 = (float)dst->ne[3] / src0->ne[3]; + GGML_TENSOR_UNARY_OP_LOCALS + + float sf0 = (float)ne0 / ne00; + float sf1 = (float)ne1 / ne01; + float sf2 = (float)ne2 / ne02; + float sf3 = (float)ne3 / ne03; + float pixel_offset = 0.5f; if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) { - sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1); - sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1); + sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; + sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1; + pixel_offset = 0.0f; } ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { (uint32_t)ggml_nelements(dst), 0, 0, - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], - (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], - sf0, sf1, sf2, sf3, + (uint32_t)ne00, (uint32_t)ne01, + (uint32_t)nb00 / src0_type_size, (uint32_t)nb01 / src0_type_size, (uint32_t)nb02 / src0_type_size, (uint32_t)nb03 / src0_type_size, + (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, + sf0, sf1, sf2, sf3, pixel_offset }, dryrun); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp index 154a2172d83db..8670aad32c380 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -7,6 +7,7 @@ layout (push_constant) uniform parameter uint nb00; uint nb01; uint nb02; uint nb03; uint ne10; uint ne11; uint ne12; uint ne13; float sf0; float sf1; float sf2; float sf3; + float pixel_offset; } p; #include "types.glsl" @@ -19,7 +20,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; // from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag #define NEAREST 0 #define BILINEAR 1 -#define ALIGN_CORNERS (1 << 8) layout (constant_id = 0) const uint scale_mode = 0; @@ -52,7 +52,7 @@ float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) { float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) { const ivec2 ne0 = ivec2(p.ne00, p.ne01); - const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5; + const vec2 c = (vec2(i10, i11) + p.pixel_offset) / vec2(p.sf0, p.sf1) - p.pixel_offset; const vec2 c0f = floor(c); const vec2 d = c - c0f; const ivec2 c0 = max(ivec2(c0f), 0); @@ -61,16 +61,6 @@ float 
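Both the OpenCL and Vulkan upscale paths now compute the align-corners scale factors with a guard for degenerate extents and fold the sampling offset into a pixel_offset value, so the shader can use the single expression (i + pixel_offset) / sf - pixel_offset for both modes. The host-side arithmetic in isolation (a sketch; ne00/ne0 stand for one source/destination extent pair):

#include <cstdio>

// Scale-factor selection for bilinear upscaling, per dimension.
// ne00 = source extent, ne0 = destination extent.
static void compute_scale(int ne00, int ne0, bool align_corners,
                          float &sf, float &pixel_offset) {
    sf           = (float) ne0 / ne00;  // default: ratio of extents, half-pixel sample centers
    pixel_offset = 0.5f;
    if (align_corners) {
        // remap the ratio only when both extents have more than one sample,
        // otherwise (ne - 1) would make the division degenerate
        if (ne0 > 1 && ne00 > 1) {
            sf = (float) (ne0 - 1) / (ne00 - 1);
        }
        pixel_offset = 0.0f;
    }
}

int main() {
    float sf, off;
    compute_scale(/*ne00=*/4, /*ne0=*/8, /*align_corners=*/true, sf, off);
    for (int i = 0; i < 8; ++i) {
        // same expression the shader uses to find the source coordinate
        std::printf("dst %d -> src %.3f\n", i, (i + off) / sf - off);
    }
    return 0;
}

With align_corners, the first and last destination samples land exactly on the first and last source samples, which is only meaningful when both extents are greater than one; otherwise the plain ratio is kept, as in the guarded assignments above.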
interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) { return fetch_bilinear(c0, c1, d, i12, i13); } -float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) { - const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1); - const vec2 c0f = floor(c); - const vec2 d = c - c0f; - const ivec2 c0 = ivec2(c0f); - const ivec2 c1 = c0 + 1; - - return fetch_bilinear(c0, c1, d, i12, i13); -} - void main() { const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; @@ -91,9 +81,6 @@ void main() { case BILINEAR: result = interpolate_bilinear(i10, i11, i12, i13); break; - case BILINEAR | ALIGN_CORNERS: - result = interpolate_bilinear_align_corners(i10, i11, i12, i13); - break; } data_d[p.d_offset + idx] = D_TYPE(result); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1b71fb3749aaa..94fcfaf69cf09 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -3062,6 +3062,7 @@ class VisionProjectorType: VOXTRAL = "voxtral" LFM2 = "lfm2" KIMIVL = "kimivl" + LIGHTONOCR = "lightonocr" # Items here are (block size, type size) diff --git a/models/templates/llama-cpp-lfm2.jinja b/models/templates/llama-cpp-lfm2.jinja new file mode 100644 index 0000000000000..b7921120bc007 --- /dev/null +++ b/models/templates/llama-cpp-lfm2.jinja @@ -0,0 +1,37 @@ +{{- bos_token -}} +{%- set system_prompt = "" -%} +{%- set ns = namespace(system_prompt="") -%} +{%- if messages[0]["role"] == "system" -%} + {%- set ns.system_prompt = messages[0]["content"] -%} + {%- set messages = messages[1:] -%} +{%- endif -%} +{%- if tools -%} + {%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: <|tool_list_start|>[" -%} + {%- for tool in tools -%} + {%- if tool is not string -%} + {%- set tool = tool | tojson -%} + {%- endif -%} + {%- set ns.system_prompt = ns.system_prompt + tool -%} + {%- if not loop.last -%} + {%- set ns.system_prompt = ns.system_prompt + ", " -%} + {%- endif -%} + {%- endfor -%} + {%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%} +{%- endif -%} +{%- if ns.system_prompt -%} + {{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}} +{%- endif -%} +{%- for message in messages -%} + {{- "<|im_start|>" + message["role"] + "\n" -}} + {%- set content = message["content"] -%} + {%- if content is not string -%} + {%- set content = content | tojson -%} + {%- endif -%} + {%- if message["role"] == "tool" -%} + {%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%} + {%- endif -%} + {{- content + "<|im_end|>\n" -}} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{- "<|im_start|>assistant\n" -}} +{%- endif -%} diff --git a/scripts/snapdragon/adb/run-bench.sh b/scripts/snapdragon/adb/run-bench.sh index 25e0662016cba..b2e651e7493d4 100755 --- a/scripts/snapdragon/adb/run-bench.sh +++ b/scripts/snapdragon/adb/run-bench.sh @@ -35,5 +35,6 @@ adb $adbserial shell " \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ $ndev $nhvx $opmask ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \ - -t 4 --batch-size 128 -ngl 99 $@ \ + --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ + --batch-size 128 -ngl 99 $@ \ " diff --git a/scripts/snapdragon/adb/run-cli.sh b/scripts/snapdragon/adb/run-cli.sh index 763482e55ab33..ab8d6d49a24e0 100755 --- a/scripts/snapdragon/adb/run-cli.sh +++ b/scripts/snapdragon/adb/run-cli.sh @@ -45,8 +45,9 @@ adb $adbserial shell " \ cd 
$basedir; ulimit -c unlimited; \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ - $verbose $experimental $sched $opmask $profile $nhvx $ndev \ - ./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \ - -t 4 --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \ + $verbose $experimental $sched $opmask $profile $nhvx $ndev \ + ./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \ + --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ + --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \ -ngl 99 --device $device $cli_opts $@ \ " diff --git a/src/llama-context.cpp b/src/llama-context.cpp index bd348bcad370a..f6192a36e0ee5 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -268,9 +268,7 @@ llama_context::llama_context( if (pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); } - } - if (!hparams.vocab_only) { llama_memory_context_ptr mctx; if (memory) { LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); @@ -343,7 +341,14 @@ llama_context::llama_context( { auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); + if (pipeline_parallel) { + LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); + gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + } + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } } n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 736693e174527..add74391f0c47 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache( const uint32_t n_layer_kv = hparams.n_layer_kv(); + // define a comparator for the buft -> ctx map to ensure that the order is well-defined: + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + std::map ctx_map; + // create a context for each buffer type - std::map ctx_map; auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { @@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache( return nullptr; } - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); + ctx_map.emplace(buft, ctx); return ctx; } - return it->second; + return it->second.get(); }; GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max); @@ -167,11 +174,8 @@ llama_kv_cache::llama_kv_cache( } // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + for (auto & [buft, ctx] : ctx_map) { + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); if (!buf) { throw std::runtime_error("failed to allocate buffer for kv cache"); } @@ -179,7 +183,7 @@ llama_kv_cache::llama_kv_cache( LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", 
__func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); ggml_backend_buffer_clear(buf, 0); - bufs.emplace_back(buf); + ctxs_bufs.emplace_back(std::move(ctx), buf); } { @@ -203,7 +207,7 @@ void llama_kv_cache::clear(bool data) { } if (data) { - for (auto & buf : bufs) { + for (auto & [_, buf] : ctxs_bufs) { ggml_backend_buffer_clear(buf.get(), 0); } } @@ -472,8 +476,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { std::map llama_kv_cache::memory_breakdown() const { std::map ret; - for (const ggml_backend_buffer_ptr & buf_ptr : bufs) { - ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get()); + for (const auto & [_, buf] : ctxs_bufs) { + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); } return ret; } @@ -1298,7 +1302,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch size_t llama_kv_cache::total_size() const { size_t size = 0; - for (const auto & buf : bufs) { + for (const auto & [_, buf] : ctxs_bufs) { size += ggml_backend_buffer_get_size(buf.get()); } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 85f0663d8c1d4..150e282596255 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -217,8 +217,8 @@ class llama_kv_cache : public llama_memory_i { // this is the SWA type of the cache - not to be confused with the model SWA type const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; - std::vector ctxs; - std::vector bufs; + // ggml contexts for the KV cache along with the allocated backend buffers: + std::vector> ctxs_bufs; // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) // note: this is not part of the KV state and it's only used to speed-up the find_slot() method diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index d67f5a5f47b87..276e1697d466c 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -32,8 +33,15 @@ llama_memory_recurrent::llama_memory_recurrent( cells.clear(); cells.resize(mem_size); + // define a comparator for the buft -> ctx map to ensure that the order is well-defined: + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + std::map ctx_map; + // create a context for each buffer type - std::map ctx_map; auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { @@ -48,13 +56,12 @@ llama_memory_recurrent::llama_memory_recurrent( return nullptr; } - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); + ctx_map.emplace(buft, ctx); return ctx; } - return it->second; + return it->second.get(); }; r_l.resize(n_layer); @@ -93,17 +100,14 @@ llama_memory_recurrent::llama_memory_recurrent( } // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + for (auto & [buft, ctx] : ctx_map) { + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); if (!buf) { throw std::runtime_error("failed to allocate buffer for rs cache"); } 
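The KV-cache and recurrent-memory code above keys its buft -> ctx map with a comparator that orders by the buffer type's name via strcmp; the same fix is applied to llama-model.cpp below, where the previous code compared the const char * pointers returned by ggml_backend_buft_name() directly and therefore effectively ordered by address. A self-contained illustration of the difference (plain string literals stand in for buffer-type names):

#include <cstdio>
#include <cstring>
#include <map>

// Orders keys by the string they point to rather than by pointer value,
// mirroring ggml_backend_buft_comparator in the diff.
struct cstr_less {
    bool operator()(const char *lhs, const char *rhs) const {
        return strcmp(lhs, rhs) < 0;
    }
};

int main() {
    const char *names[] = { "Vulkan0", "CPU", "Metal" };

    // With the default std::less<const char *>, iteration order would depend on
    // where the strings happen to live in memory; with cstr_less it is lexicographic,
    // so the buffer-type -> context map is traversed in a well-defined order.
    std::map<const char *, int, cstr_less> by_name;
    for (int i = 0; i < 3; ++i) {
        by_name[names[i]] = i;
    }
    for (const auto &[name, idx] : by_name) {
        std::printf("%s -> %d\n", name, idx);
    }
    return 0;
}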
ggml_backend_buffer_clear(buf, 0); LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); - bufs.emplace_back(buf); + ctxs_bufs.emplace_back(std::move(ctx), buf); } { @@ -129,7 +133,7 @@ void llama_memory_recurrent::clear(bool data) { used = 0; if (data) { - for (auto & buf : bufs) { + for (auto & [_, buf] : ctxs_bufs) { ggml_backend_buffer_clear(buf.get(), 0); } } @@ -364,8 +368,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; - for (const ggml_backend_buffer_ptr & buf_ptr : bufs) { - ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get()); + for (const auto & [_, buf] : ctxs_bufs) { + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); } return ret; } @@ -662,7 +666,7 @@ bool llama_memory_recurrent::get_can_shift() const { size_t llama_memory_recurrent::total_size() const { size_t size = 0; - for (const auto & buf : bufs) { + for (const auto & [_, buf] : ctxs_bufs) { size += ggml_backend_buffer_get_size(buf.get()); } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 077c6e3ce938d..47f01d7391248 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -109,8 +109,8 @@ class llama_memory_recurrent : public llama_memory_i { const uint32_t n_seq_max = 1; - std::vector ctxs; - std::vector bufs; + // ggml contexts for the KV cache along with the allocated backend buffers: + std::vector> ctxs_bufs; size_t total_size() const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 05e467180089e..bb83a04e96055 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2231,7 +2231,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // define a comparator for the buft -> ctx map to ensure that the order is well-defined: struct ggml_backend_buft_comparator { bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { - return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs); + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; } }; std::map ctx_map; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2e2a87ac4f518..aee1730137900 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7049,6 +7049,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {5, 7, 11, 13}, {2, 5, 7, 11}, mode)); } test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)); + test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)); + test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)); test_cases.emplace_back(new test_sum()); test_cases.emplace_back(new test_sum_rows()); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 52e23b5ac61f5..4a8ba849b3f8c 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -16,6 +16,7 @@ #include #include +#include #include using json = nlohmann::ordered_json; @@ -2138,6 +2139,154 @@ static void test_template_output_parsers() { assert_equals(true, 
common_chat_templates_support_enable_thinking(tmpls.get())); } + { + // LFM2 format tests + auto tmpls = read_templates("models/templates/llama-cpp-lfm2.jinja"); + std::vector end_tokens{ "<|im_end|>" }; + + auto inputs_tools_forced_json_schema = std::invoke([&]() -> common_chat_templates_inputs { + common_chat_templates_inputs inputs; + inputs.messages = { + std::invoke([&]() -> common_chat_msg { + common_chat_msg msg; + msg.role = "system"; + msg.content = "force json schema.\n"; + return msg; + }), + message_user, + }; + inputs.tools = {special_function_tool}; + return inputs; + }); + + { + auto params = common_chat_templates_apply(tmpls.get(), inputs_no_tools); + assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params.format); + assert_equals(false, params.grammar_lazy); + assert_equals(std::string(R"(<|im_start|>user +Hey there!<|im_end|> +<|im_start|>assistant +)"), params.prompt); + } + + { + auto params = common_chat_templates_apply(tmpls.get(), inputs_tools); + assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params.format); + assert_equals(false, params.grammar_lazy); + assert_equals(std::string(R"(<|im_start|>system +List of tools: <|tool_list_start|>[{"type": "function", "function": {"name": "special_function", "description": "I'm special", "parameters": {"type": "object", "properties": {"arg1": {"type": "integer", "description": "The arg."}}, "required": ["arg1"]}}}]<|tool_list_end|><|im_end|> +<|im_start|>user +Hey there!<|im_end|> +<|im_start|>assistant +)"), params.prompt); + assert_equals(true, params.grammar.empty()); + } + + { + auto params = common_chat_templates_apply(tmpls.get(), inputs_tools_forced_json_schema); + assert_equals(COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, params.format); + assert_equals(true, params.grammar_lazy); + assert_equals(std::string(R"(<|im_start|>system +List of tools: <|tool_list_start|>[{"type": "function", "function": {"name": "special_function", "description": "I'm special", "parameters": {"type": "object", "properties": {"arg1": {"type": "integer", "description": "The arg."}}, "required": ["arg1"]}}}]<|tool_list_end|><|im_end|> +<|im_start|>user +Hey there!<|im_end|> +<|im_start|>assistant +)"), params.prompt); + assert_equals(false, params.grammar.empty()); + } + + // Test parsing regular content + assert_msg_equals(message_assist, + common_chat_parse( + "Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test single tool call with JSON format + common_chat_msg msg_single_tool_call; + msg_single_tool_call.role = "assistant"; + msg_single_tool_call.tool_calls.push_back({"special_function", "{\"arg1\":1}", ""}); + assert_msg_equals( + msg_single_tool_call, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with string argument + common_chat_msg msg_tool_call_string; + msg_tool_call_string.role = "assistant"; + msg_tool_call_string.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_tool_call_string, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with multiple arguments + common_chat_msg msg_multi_args; + msg_multi_args.role = "assistant"; + msg_multi_args.tool_calls.push_back({"calculate", 
"{\"x\":10,\"y\":20,\"operation\":\"add\"}", ""}); + assert_msg_equals( + msg_multi_args, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"calculate\", \"arguments\": {\"x\": 10, \"y\": 20, \"operation\": \"add\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test multiple tool calls in single array + common_chat_msg msg_multiple_tools; + msg_multiple_tools.role = "assistant"; + msg_multiple_tools.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + msg_multiple_tools.tool_calls.push_back({"get_time", "{\"timezone\":\"UTC\"}", ""}); + assert_msg_equals( + msg_multiple_tools, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}, {\"name\": \"get_time\", \"arguments\": {\"timezone\": \"UTC\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with content before + common_chat_msg msg_content_before_tool; + msg_content_before_tool.role = "assistant"; + msg_content_before_tool.content = "Let me check the weather for you."; + msg_content_before_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_content_before_tool, + common_chat_parse( + "Let me check the weather for you.<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with content after + common_chat_msg msg_content_after_tool; + msg_content_after_tool.role = "assistant"; + msg_content_after_tool.content = "Here's the result."; + msg_content_after_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_content_after_tool, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>Here's the result.", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with newlines (common in LLM output) + common_chat_msg msg_tool_call_newlines; + msg_tool_call_newlines.role = "assistant"; + msg_tool_call_newlines.tool_calls.push_back({"get_current_time", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_tool_call_newlines, + common_chat_parse( + "<|tool_call_start|>[{\n \"name\": \"get_current_time\",\n \"arguments\": {\n \"location\": \"Paris\"\n }\n}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Note: LFM2 uses JSON format for tool calls: [{"name": "...", "arguments": {...}}] + // Unlike other formats, LFM2 template does not render tool calls in conversation history, + // so we don't use test_templates() for tool call generation. Instead, the parsing tests + // above verify edge cases and format variations for the tool call output format. 
+ } } diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 67df240c6fef3..8a55bc54ae466 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -1124,9 +1124,9 @@ static void test_all(const std::string & lang, std::function PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_VOXTRAL, "voxtral"}, { PROJECTOR_TYPE_LFM2, "lfm2"}, { PROJECTOR_TYPE_KIMIVL, "kimivl"}, + { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f2abf88523843..b44f0a3a28ad2 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -171,7 +171,7 @@ struct clip_hparams { int32_t n_head; int32_t n_layer; // idefics3 - int32_t preproc_image_size = 0; + int32_t preproc_image_size = 0; // aka max_dimension int32_t proj_scale_factor = 0; float image_mean[3]; @@ -621,7 +621,7 @@ struct clip_graph { } // arrangement of the [IMG_BREAK] token - { + if (model.token_embd_img_break) { // not efficient, but works // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows] // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension @@ -2095,6 +2095,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 res = graph.build_siglip(); } break; case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: { res = graph.build_pixtral(); } break; @@ -2380,6 +2381,7 @@ struct clip_model_loader { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false); } break; case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: { hparams.rope_theta = 10000.0f; hparams.warmup_image_size = hparams.patch_size * 8; @@ -2722,6 +2724,15 @@ struct clip_model_loader { model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); } break; + case PROJECTOR_TYPE_LIGHTONOCR: + { + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); + model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); + } break; case PROJECTOR_TYPE_ULTRAVOX: { model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); @@ -3210,8 +3221,8 @@ struct image_manipulation { return {0, 0}; } - float scale = std::min(1.0f, std::min(static_cast(max_dimension) / inp_size.width, - static_cast(max_dimension) / inp_size.height)); + float scale = std::min(static_cast(max_dimension) / inp_size.width, + static_cast(max_dimension) / inp_size.height); float target_width_f = static_cast(inp_size.width) * scale; float target_height_f = static_cast(inp_size.height) * scale; @@ -3374,7 +3385,7 @@ struct llava_uhd { // resize to overview size clip_image_u8_ptr resized_img(clip_image_u8_init()); - image_manipulation::bicubic_resize(*img, *resized_img, inst.overview_size.width, inst.overview_size.height); + image_manipulation::resize_and_pad_image(*img, *resized_img, inst.overview_size); output.push_back(std::move(resized_img)); if (inst.slices.empty()) { // no slices, just return the resized image @@ -3576,6 +3587,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // 
CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/image_processing_idefics3.py#L737 const clip_image_size refined_size = image_manipulation::calc_size_preserved_ratio( original_size, params.image_size, params.preproc_image_size); + // LOG_INF("%s: original size: %d x %d, refined size: %d x %d\n", + // __func__, original_size.width, original_size.height, + // refined_size.width, refined_size.height); llava_uhd::slice_instructions instructions; instructions.overview_size = clip_image_size{params.image_size, params.image_size}; @@ -3586,6 +3600,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str }; for (int y = 0; y < refined_size.height; y += params.image_size) { for (int x = 0; x < refined_size.width; x += params.image_size) { + // LOG_INF("%s: adding slice at x=%d, y=%d\n", __func__, x, y); instructions.slices.push_back(llava_uhd::slice_coordinates{ /* x */x, /* y */y, @@ -3622,7 +3637,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->entries.push_back(std::move(img_f32)); return true; - } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL + || ctx->proj_type() == PROJECTOR_TYPE_LIGHTONOCR + ) { clip_image_u8 resized_image; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size); image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height); @@ -3865,12 +3882,17 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im n_patches = x_patch * y_patch; } break; case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: { // dynamic size int n_merge = params.spatial_merge_size; int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1); int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1); - n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row + if (ctx->model.token_embd_img_break) { + n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row + } else { + n_patches = n_patches_y * n_patches_x; + } } break; case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_ULTRAVOX: @@ -4247,6 +4269,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_LIGHTONOCR: { // set the 2D positions int n_patches_per_col = image_size_width / patch_size; @@ -4377,6 +4400,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_model_peg_0_b->ne[0]; case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_MLP_NORM: return ctx->model.mm_3_b->ne[0]; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 4d487581ae0a0..3b901bfac8215 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -275,6 +275,11 @@ struct mtmd_context { img_beg = ""; img_end = ""; + } else if (proj == PROJECTOR_TYPE_LIGHTONOCR) { + // <|im_start|> ... (image embeddings) ... 
<|im_end|> + img_beg = "<|im_start|>"; + img_end = "<|im_end|>"; + } } diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index dbdf7656a66d9..c2270746360ec 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -70,6 +70,7 @@ add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0" add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" add_test_vision "ggml-org/LFM2-VL-450M-GGUF:Q8_0" add_test_vision "ggml-org/granite-docling-258M-GGUF:Q8_0" +add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0" add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0" add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" @@ -138,7 +139,10 @@ for i in "${!arr_hf[@]}"; do echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log - if echo "$output" | grep -iq "new york"; then + # either contains "new york" or both "men" and "walk" + if echo "$output" | grep -iq "new york" \ + || (echo "$output" | grep -iq "men" && echo "$output" | grep -iq "walk") + then result="$prefix \033[32mOK\033[0m: $bin $hf" else result="$prefix \033[31mFAIL\033[0m: $bin $hf" diff --git a/tools/server/public_legacy/json-schema-to-grammar.mjs b/tools/server/public_legacy/json-schema-to-grammar.mjs index 6f0952974496a..1d9dc5105eee9 100644 --- a/tools/server/public_legacy/json-schema-to-grammar.mjs +++ b/tools/server/public_legacy/json-schema-to-grammar.mjs @@ -345,10 +345,14 @@ export class SchemaConverter { const selectors = ref.split('#')[1].split('/').slice(1); for (const sel of selectors) { - if (!target || !(sel in target)) { + const selIndex = parseInt(sel, 10); + if (target && sel in target) { + target = target[sel]; + } else if (target && selIndex in target) { + target = target[selIndex]; + } else { throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`); } - target = target[sel]; } this._refs[ref] = target; @@ -594,7 +598,8 @@ export class SchemaConverter { } _resolveRef(ref) { - let refName = ref.split('/').pop(); + let refFragment = ref.split('#').pop(); + let refName = 'ref' + refFragment.replace(/[^a-zA-Z0-9-]+/g, '-'); if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) { this._refsBeingResolved.add(ref); const resolved = this._refs[ref];