diff --git a/jina_embeddings/README.md b/jina_embeddings/README.md new file mode 100644 index 0000000000000..6a4f550a95c10 --- /dev/null +++ b/jina_embeddings/README.md @@ -0,0 +1,14 @@ +# Inference example +```bash +python infer.py \ + --llama-bin /home/andrei/workspace/llama.cpp/build/bin/llama-server \ + --model /home/andrei/workspace/gguf/jev4-bf16.gguf \ + --mmproj /home/andrei/workspace/gguf/mmproj-jev4-bf16.gguf \ + --gpus 7 \ + --input /home/andrei/workspace/test_data.txt \ + --output /home/andrei/workspace/jev4_mmtd.json \ + --save-cosine-sim-path /home/andrei/workspace/jev4_mmtd.md \ + --query-prefix "Query: " \ + --document-prefix "Passage: " \ + --normalize-after-pooling +``` \ No newline at end of file diff --git a/jina_embeddings/infer.py b/jina_embeddings/infer.py new file mode 100644 index 0000000000000..1241560ce4ec6 --- /dev/null +++ b/jina_embeddings/infer.py @@ -0,0 +1,144 @@ +import json +import os +import signal +import subprocess + +import click # type: ignore +import numpy as np # type: ignore +from sklearn.metrics.pairwise import cosine_similarity # type: ignore + +from model import LlamaCppServerEmbeddingModel + + +def clip_text(text: str, max_len: int = 10) -> str: + """Clip text to max_len characters, showing first part + '...' if needed""" + if len(text) <= max_len: + return text + return text[:max_len-3] + "..." + + +def save_cosine_similarity_matrix(raw_lines: list[str], embeddings: np.ndarray, save_path: str) -> None: + """Save cosine similarity matrix as markdown table""" + # Extract display names from original texts + display_names = [] + for text in raw_lines: + if text.startswith('[QUERY] '): + content = text[8:] + display_names.append(f"Q:{clip_text(content)}") + elif text.startswith('[DOCUMENT] '): + content = text[11:] + display_names.append(f"D:{clip_text(content)}") + elif text.startswith('[IMAGE] '): + image_path = text[8:] + filename = os.path.basename(image_path) + display_names.append(f"I:{clip_text(filename)}") + else: + display_names.append(clip_text(text)) + + # Compute cosine similarity matrix + similarity_matrix = cosine_similarity(embeddings) + + # Create markdown table + with open(save_path, 'w', encoding='utf-8') as f: + f.write("# Cosine Similarity Matrix\n\n") + + # Write header row + f.write("| Item |") + for name in display_names: + f.write(f" {name} |") + f.write("\n") + + # Write separator row + f.write("|" + "---|" * (len(display_names) + 1) + "\n") + + # Write data rows + for i, row_name in enumerate(display_names): + f.write(f"| {row_name} |") + for j in range(len(display_names)): + sim_score = similarity_matrix[i, j] + f.write(f" {sim_score:.3f} |") + f.write("\n") + + print(f"Saved cosine similarity matrix to {save_path}") + + +@click.command() +@click.option('--llama-bin', default='./llama-server', help='Path to llama-server binary') +@click.option('--model', required=True, help='Path to model .gguf file') +@click.option('--mmproj', required=True, help='Path to mmproj .gguf file') +@click.option('--port', default=8080, help='Port for llama-server') +@click.option('--host', default='0.0.0.0', help='Host for llama-server') +@click.option('--ngl', default=999, help='Number of GPU layers') +@click.option('--gpus', default='0', help='CUDA_VISIBLE_DEVICES comma separated GPU ids (e.g. "0,1")') +@click.option('--input', 'input_path', required=True, help='Path to input txt file. Format: "[TYPE] content" where TYPE is QUERY, DOCUMENT, or IMAGE. 
For IMAGE, content should be the file path.') +@click.option('--output', 'output_path', required=True, help='Path to output JSON file for embeddings') +@click.option('--normalize-after-pooling', is_flag=True, default=False, help='Apply L2 normalization after pooling') +@click.option('--save-cosine-sim-path', help='Path to save cosine similarity matrix as markdown table') +@click.option('--query-prefix', default='Query: ', help='Prefix for [QUERY] lines') +@click.option('--document-prefix', default='Passage: ', help='Prefix for [DOCUMENT] lines') +@click.option('--image-prefix', default='Describe the image.<__image__>', help='Prefix for [IMAGE] lines') +def main( + llama_bin, model, mmproj, port, host, ngl, gpus, + input_path, output_path, + normalize_after_pooling, + save_cosine_sim_path, query_prefix, document_prefix, image_prefix +): + env = os.environ.copy() + env['CUDA_VISIBLE_DEVICES'] = gpus + + cmd = [ + llama_bin, + '-m', model, + '--mmproj', mmproj, + '--embedding', + '--port', str(port), + '-ngl', str(ngl), + '--host', host, + '--pooling', 'none' + ] + print(f"Starting llama-server with: {' '.join(cmd)}") + proc = subprocess.Popen(cmd, env=env) + + try: + with open(input_path, 'r', encoding='utf-8') as f: + raw_lines = [line.strip() for line in f if line.strip()] + + print(f"Loaded {len(raw_lines)} sentences from {input_path}") + + model = LlamaCppServerEmbeddingModel( + server_url=f"http://{host}:{port}", + normalize_after_pooling=normalize_after_pooling, + query_prefix=query_prefix, + document_prefix=document_prefix, + image_prefix=image_prefix + ) + + model.wait_for_server() + original_texts, embeddings = model.encode_from_lines(raw_lines) + + output_data = [ + {"text": text, "embedding": embedding.tolist()} + for text, embedding in zip(original_texts, embeddings) + ] + + with open(output_path, 'w', encoding='utf-8') as f_out: + json.dump(output_data, f_out, indent=2) + + print(f"Saved embeddings to {output_path}") + + # Save cosine similarity matrix if requested + if save_cosine_sim_path: + save_cosine_similarity_matrix(raw_lines, embeddings, save_cosine_sim_path) + + finally: + print("Shutting down server...") + proc.send_signal(signal.SIGINT) + try: + proc.wait(timeout=10) + except subprocess.TimeoutExpired: + print("Server did not shut down in time; killing process.") + proc.kill() + + +if __name__ == '__main__': + main() # type: ignore \ No newline at end of file diff --git a/jina_embeddings/model.py b/jina_embeddings/model.py new file mode 100644 index 0000000000000..3f71b7dd40db8 --- /dev/null +++ b/jina_embeddings/model.py @@ -0,0 +1,167 @@ +import base64 +import os +import time +from typing import List, Optional, Tuple + +import numpy as np # type: ignore +import requests # type: ignore +from typing_extensions import TypedDict # type: ignore + + +class EmbeddingRequestItem(TypedDict): + content: str + image: Optional[str] + + +class LlamaCppServerEmbeddingModel: + def __init__( + self, + server_url: str = "http://localhost:8080", + normalize_after_pooling: bool = False, + query_prefix: str = "Query: ", + document_prefix: str = "Passage: ", + image_prefix: str = "Describe the image.<__image__>" + ) -> None: + self.server_url = server_url + self.normalize_after_pooling = normalize_after_pooling + self.query_prefix = query_prefix + self.document_prefix = document_prefix + self.image_prefix = image_prefix + + def wait_for_server(self, max_wait_time: int = 300, check_interval: int = 2) -> None: + """Wait for the server to be ready""" + print("Waiting for server to 
start...") + test_payload = {"content": "test"} + + start_time = time.time() + while True: + elapsed = time.time() - start_time + if elapsed > max_wait_time: + raise TimeoutError(f"Server did not become ready within {max_wait_time} seconds") + try: + r = requests.post(f"{self.server_url}/embedding", json=test_payload, timeout=10) + assert r.status_code == 200, f"Server not ready: {r.status_code}" + print("✅ Server is ready!") + break + except (requests.exceptions.RequestException, AssertionError): + print(f"⏳ Waiting for server to start... ({elapsed:.1f}s elapsed)") + time.sleep(check_interval) + + def _parse_line(self, line: str) -> Tuple[str, EmbeddingRequestItem]: + """Parse input line and return (original_content, EmbeddingRequestItem)""" + if line.startswith('[QUERY] '): + content = line[8:] # Remove '[QUERY] ' + item: EmbeddingRequestItem = { "content": self.query_prefix + content, "image": None } + return content, item + elif line.startswith('[DOCUMENT] '): + content = line[11:] # Remove '[DOCUMENT] ' + item: EmbeddingRequestItem = { "content": self.document_prefix + content, "image": None } + return content, item + elif line.startswith('[IMAGE] '): + image_path = line[8:] # Remove '[IMAGE] ' + data_url, success = self._process_image(image_path) + assert success, f"Failed to process image: {image_path}" + item: EmbeddingRequestItem = { "content": self.image_prefix, "image": data_url } + return image_path, item + else: + raise ValueError(f"Invalid line format: {line}. Expected '[QUERY] ', '[DOCUMENT] ', or '[IMAGE] ' prefix.") + + def _process_image(self, image_path: str) -> Tuple[Optional[str], bool]: + """Process image file and return (data_url, success)""" + try: + with open(image_path, 'rb') as img_file: + image_data = base64.b64encode(img_file.read()).decode('utf-8') + + # Detect image format from extension + ext = os.path.splitext(image_path)[1].lower() + if ext in ['.jpg', '.jpeg']: + mime_type = 'image/jpeg' + elif ext == '.png': + mime_type = 'image/png' + elif ext == '.webp': + mime_type = 'image/webp' + else: + mime_type = 'image/jpeg' # default + + data_url = f"data:{mime_type};base64,{image_data}" + return data_url, True + + except FileNotFoundError: + print(f"❌ Image not found: {image_path}, processing as text only") + return None, False + + def encode(self, items: List[EmbeddingRequestItem]) -> np.ndarray: + """ + Encode items. Each item should be an EmbeddingRequestItem. 
+ """ + embeddings = [] + + for i, item in enumerate(items): + payload = {"content": item["content"]} + if item["image"]: + payload["image"] = item["image"] + + is_image_request = item["image"] is not None + response = requests.post(f"{self.server_url}/embedding", json=payload) + assert response.status_code == 200, f"Server error: {response.text}" + embedding_data = response.json() + raw_embedding = embedding_data["embedding"] + + # TODO: optional enable logging via argument + print(f"\n==========================") + print(f"🧠 Item {i + 1} embedding response") + print(f"📦 Type: {type(embedding_data).__name__}") + print(f"🔑 Keys: {list(embedding_data.keys())}") + print(f"🔎 Preview: {repr(embedding_data)[:500]}") + print(f"🔍 Raw embedding type: {type(raw_embedding)}") + print(f"🔍 Raw embedding shape: {np.array(raw_embedding).shape}") + print(f"==========================") + + # Check if embeddings are already normalized + embedding_array = np.array(raw_embedding) + norms = np.linalg.norm(embedding_array, axis=1) + if np.allclose(norms, 1.0, atol=1e-6): + print(f"⚠️ WARNING: Raw embeddings appear to be already normalized!") + + # Handle image token extraction + if is_image_request: + start_idx = embedding_data["start_image_token_idx"] + end_idx = embedding_data["end_image_token_idx"] + hidden_states = np.array(raw_embedding) + image_embeddings = hidden_states[start_idx:end_idx+1] # +1 for inclusive end + pooled = image_embeddings.mean(axis=0) + print(f"🖼️ Image token indices: start={start_idx}, end={end_idx}") + print(f"🖼️ Extracted image embeddings shape: {image_embeddings.shape}") + print(f"🖼️ Original total embeddings: {len(raw_embedding)}") + print(f"🖼️ Image embeddings extracted: {len(image_embeddings)}") + else: + # Regular text processing - always mean pool the tokens + hidden_states = np.array(raw_embedding) + pooled = hidden_states.mean(axis=0) + + # Optional normalization + if self.normalize_after_pooling: + norm = np.linalg.norm(pooled) + if norm > 0: + pooled = pooled / norm + print(f"🔄 Applied L2 normalization") + + embeddings.append(pooled) + + return np.array(embeddings) + + def encode_from_lines(self, raw_lines: List[str]) -> Tuple[List[str], np.ndarray]: + """ + Process raw lines with type prefixes and return embeddings along with original content + Returns: (original_texts, embeddings) + """ + original_texts = [] + items = [] + + for line in raw_lines: + original, item = self._parse_line(line.strip()) + original_texts.append(original) + items.append(item) + + embeddings = self.encode(items) + return original_texts, embeddings \ No newline at end of file diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 022b5d0b31034..e840e3d20dd59 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -491,6 +491,7 @@ struct server_task { } } } + // set reverse prompt from cli args if not set in the request if (params.antiprompt.empty()) { params.antiprompt = defaults.antiprompt; @@ -1044,6 +1045,10 @@ struct server_task_result_embd : server_task_result { std::vector> embedding; int32_t n_tokens; + int32_t start_image_token_idx = -1; // -1 means no image + int32_t end_image_token_idx = -1; + + bool has_stored_embeddings = false; // OAI-compat fields oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; @@ -1059,18 +1064,34 @@ struct server_task_result_embd : server_task_result { } json to_json_non_oaicompat() { - return json { + json result = json { {"index", index}, {"embedding", embedding}, }; + + // Add image indices if this was a multimodal request + if 
(start_image_token_idx != -1) { + result["start_image_token_idx"] = start_image_token_idx; + result["end_image_token_idx"] = end_image_token_idx; + } + + return result; } json to_json_oaicompat() { - return json { + json result = json { {"index", index}, {"embedding", embedding[0]}, {"tokens_evaluated", n_tokens}, }; + + // Add image indices for OAI-compat too + if (start_image_token_idx != -1) { + result["start_image_token_idx"] = start_image_token_idx; + result["end_image_token_idx"] = end_image_token_idx; + } + + return result; } }; @@ -1302,6 +1323,11 @@ struct server_slot { std::vector generated_token_probs; + // Fields for storing embeddings when processing multi-modal inputs + std::vector> stored_pre_image_embeddings; + std::vector> stored_image_embeddings; + bool has_stored_embeddings = false; + bool has_next_token = true; bool has_new_line = false; bool truncated = false; @@ -1355,6 +1381,11 @@ struct server_slot { json_schema = json(); generated_tool_call_ids.clear(); + // *** NEW: Clear multimodal embedding storage *** + stored_pre_image_embeddings.clear(); + stored_image_embeddings.clear(); + has_stored_embeddings = false; + // clear speculative decoding stats n_draft_total = 0; n_draft_accepted = 0; @@ -2570,6 +2601,10 @@ struct server_context { } void send_embedding(const server_slot & slot, const llama_batch & batch) { + printf("=== send_embedding DEBUG ===\n"); + printf("batch.n_tokens = %d\n", batch.n_tokens); + printf("has_stored_embeddings = %s\n", slot.has_stored_embeddings ? "true" : "false"); + auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; @@ -2577,39 +2612,119 @@ struct server_context { res->oaicompat = slot.params.oaicompat; const int n_embd = llama_model_n_embd(model); - - std::vector embd_res(n_embd, 0.0f); - - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; + printf("n_embd = %d\n", n_embd); + + int pooling_type = llama_pooling_type(slot.ctx); + printf("pooling_type = %d (NONE=%d)\n", pooling_type, LLAMA_POOLING_TYPE_NONE); + + if (slot.has_stored_embeddings) { + // *** MULTIMODAL EMBEDDING ASSEMBLY *** + printf("ASSEMBLY: Multimodal embedding - combining all parts\n"); + + // Part 1: Pre-image text embeddings + for (const auto& pre_embd : slot.stored_pre_image_embeddings) { + res->embedding.push_back(pre_embd); } - const float * embd = nullptr; - if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { - embd = llama_get_embeddings_ith(ctx, i); - } else { - embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + printf("ASSEMBLY: Added %zu pre-image embeddings\n", slot.stored_pre_image_embeddings.size()); + + // *** SET IMAGE START INDEX *** + res->start_image_token_idx = slot.stored_pre_image_embeddings.size(); + + // Part 2: Image embeddings + for (const auto& img_embd : slot.stored_image_embeddings) { + res->embedding.push_back(img_embd); } - if (embd == nullptr) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - - res->embedding.push_back(std::vector(n_embd, 0.0f)); - continue; + printf("ASSEMBLY: Added %zu image embeddings\n", slot.stored_image_embeddings.size()); + + // *** SET IMAGE END INDEX *** + res->end_image_token_idx = slot.stored_pre_image_embeddings.size() + slot.stored_image_embeddings.size() - 1; + + // Part 3: Post-image text embeddings (current batch) + if (pooling_type != LLAMA_POOLING_TYPE_NONE) { + printf("ASSEMBLY: Using sequence-level pooling for post-image text\n"); + 
const float * embd = llama_get_embeddings_seq(ctx, slot.id); + if (embd != nullptr) { + std::vector embd_res(n_embd); + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + printf("ASSEMBLY: Added pooled post-image embedding\n"); + } else { + printf("ASSEMBLY: ERROR - llama_get_embeddings_seq returned NULL for post-image!\n"); + } + } else { + printf("ASSEMBLY: Using token-level embeddings for post-image text\n"); + int post_image_embeddings = 0; + + // For multimodal, we need to get embeddings from the current batch + for (int pos = 0; pos < batch.n_tokens; ++pos) { + const float * embd = llama_get_embeddings_ith(ctx, pos); + if (embd != nullptr) { + res->embedding.emplace_back(embd, embd + n_embd); + post_image_embeddings++; + if (pos < 3) { + printf("ASSEMBLY: Found post-image embedding at batch pos %d\n", pos); + } + } else { + if (pos < 3) { + printf("ASSEMBLY: No post-image embedding at batch pos %d\n", pos); + } + } + } + printf("ASSEMBLY: Added %d post-image embeddings from current batch\n", post_image_embeddings); } + + printf("ASSEMBLY: Total multimodal embeddings: %zu (pre:%zu + img:%zu + post:current)\n", + res->embedding.size(), + slot.stored_pre_image_embeddings.size(), + slot.stored_image_embeddings.size()); + + printf("ASSEMBLY: Image token indices: start=%d, end=%d\n", + res->start_image_token_idx, res->end_image_token_idx); - // normalize only when there is pooling - if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, 2); - res->embedding.push_back(embd_res); - break; + } else { + // *** REGULAR TEXT-ONLY EMBEDDING (UNCHANGED LOGIC) *** + printf("ASSEMBLY: Text-only embedding - using existing logic\n"); + + if (pooling_type != LLAMA_POOLING_TYPE_NONE) { + printf("Using sequence-level pooling\n"); + // Sequence-level pooling - get the pooled embedding for the entire sequence + const float * embd = llama_get_embeddings_seq(ctx, slot.id); + printf("llama_get_embeddings_seq returned: %p\n", (void*)embd); + + if (embd != nullptr) { + std::vector embd_res(n_embd); + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + printf("Added pooled embedding, size = %zu\n", embd_res.size()); + } else { + printf("ERROR: llama_get_embeddings_seq returned NULL!\n"); + } } else { - res->embedding.emplace_back(embd, embd + n_embd); + printf("Using token-level embeddings\n"); + // Token-level embeddings - get embeddings for each position in the sequence + int embeddings_found = 0; + for (int pos = 0; pos < slot.n_past; ++pos) { + const float * embd = llama_get_embeddings_ith(ctx, pos); + if (embd != nullptr) { + res->embedding.emplace_back(embd, embd + n_embd); + embeddings_found++; + if (pos < 5 || pos >= slot.n_past - 5) { + printf("Found embedding at pos %d\n", pos); + } + } else { + if (pos < 5 || pos >= slot.n_past - 5) { + printf("No embedding at pos %d\n", pos); + } + } + } + printf("Total embeddings found: %d out of %d positions\n", embeddings_found, slot.n_past); } } - SLT_DBG(slot, "%s", "sending embeddings\n"); + printf("Final embedding count: %zu\n", res->embedding.size()); + printf("=== send_embedding END ===\n"); queue_results.send(std::move(res)); } @@ -3301,10 +3416,36 @@ struct server_context { // check if we should process the image if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { + printf("=== IMAGE PROCESSING DEBUG ===\n"); + printf("Before process_chunk: slot.n_past = %d, 
slot.n_prompt_tokens = %d\n", + slot.n_past, slot.n_prompt_tokens); + // process the image int32_t new_n_past; int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); + + // CAPTURE IMAGE EMBEDDINGS immediately after process_chunk() + if (res == 0 && slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { + const int n_embd = llama_model_n_embd(model); + int image_embeddings_found = 0; + + for (int batch_pos = 0; batch_pos < 512; batch_pos++) { // reasonable upper bound + const float * embd = llama_get_embeddings_ith(ctx, batch_pos); + if (embd != nullptr) { + slot.stored_image_embeddings.emplace_back(embd, embd + n_embd); + image_embeddings_found++; + } else { + break; // Stop at first nullptr + } + } + + printf("STORAGE: Captured %d image embeddings dynamically\n", image_embeddings_found); + slot.has_stored_embeddings = true; + } + int32_t n_pos = new_n_past - slot.n_past; + printf("process_chunk result: res = %d, old_n_past = %d, new_n_past = %d, n_pos = %d\n", + res, slot.n_past, new_n_past, n_pos); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); @@ -3321,19 +3462,18 @@ struct server_context { slot.n_past += n_pos; slot.n_prompt_tokens_processed += n_pos; + + printf("=== IMAGE PROCESSING END ===\n"); } // add prompt tokens for processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { - // get next token to process llama_token cur_tok = slot.prompt_tokens[slot.n_past]; if (cur_tok == LLAMA_TOKEN_NULL) { break; // end of text chunk } - // embedding requires all tokens in the batch to be output const bool need_embd = server_task_type_need_embd(slot.task_type); - common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); slot.cache_tokens.push_back(cur_tok); @@ -3410,6 +3550,31 @@ struct server_context { const int ret = llama_decode(ctx, batch_view); + // NOTE: Added this embedding capture to store emebeddings retrieved before image + if (ret == 0) { + for (auto & slot : slots) { + if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { + // Capture pre-image embeddings (only before image processing) + if (!slot.has_stored_embeddings) { + const int n_embd = llama_model_n_embd(model); + int pre_image_captured = 0; + + for (int batch_pos = 0; batch_pos < batch_view.n_tokens; batch_pos++) { + const float * embd = llama_get_embeddings_ith(ctx, batch_pos); + if (embd != nullptr) { + slot.stored_pre_image_embeddings.emplace_back(embd, embd + n_embd); + pre_image_captured++; + } + } + + if (pre_image_captured > 0) { + printf("PRE-IMAGE CAPTURE: Stored %d pre-image embeddings\n", pre_image_captured); + } + } + } + } + } + metrics.on_decoded(slots); if (ret != 0) { @@ -4521,10 +4686,9 @@ int main(int argc, char ** argv) { json tokens_response = json::array(); if (body.count("content") != 0) { const bool add_special = json_value(body, "add_special", false); - const bool parse_special = json_value(body, "parse_special", true); const bool with_pieces = json_value(body, "with_pieces", false); - llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special); + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, true); if (with_pieces) { for (const auto& token : tokens) { @@ -4569,105 +4733,176 @@ int main(int argc, char ** argv) { res_ok(res, data); }; - const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) { - if 
(!ctx_server.params_base.embedding) { - res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); - return; - } - - if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); - return; - } + // new embeddings implementation + const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok]( + const httplib::Request & req, + httplib::Response & res, + const std::vector & files, + oaicompat_type oaicompat) -> void { + + const json data = json::parse(req.body); + const auto & prompt = oaicompat ? data.at("prompt") : data.at("content"); + + printf("EMBEDDINGS: Processing prompt with %zu files\n", files.size()); + + // Process files + mtmd::bitmaps bitmaps; + const bool has_mtmd = ctx_server.mctx != nullptr; + + printf("EMBEDDINGS: Multimodal context available: %s\n", has_mtmd ? "YES" : "NO"); + + if (!has_mtmd && !files.empty()) { + throw std::runtime_error("This server does not support multimodal"); + } + + for (size_t i = 0; i < files.size(); i++) { + printf("EMBEDDINGS: Processing file %zu, size: %zu bytes\n", i, files[i].size()); + mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(ctx_server.mctx, files[i].data(), files[i].size())); + if (!bmp.ptr) { + throw std::runtime_error("Failed to load image or audio file"); + } + std::string hash = fnv_hash(bmp.data(), bmp.n_bytes()); + bmp.set_id(hash.c_str()); + printf("EMBEDDINGS: File %zu processed, hash: %s\n", + i, hash.c_str()); + bitmaps.entries.push_back(std::move(bmp)); + } + + // Process prompt + std::vector inputs; + + if (has_mtmd && !files.empty()) { + // multimodal tokenization + std::string prompt_str; + if (prompt.is_string()) { + prompt_str = prompt.get(); + } else { + prompt_str = prompt.dump(); + } + + printf("EMBEDDINGS: Tokenizing multimodal prompt: \"%.100s%s\"\n", + prompt_str.c_str(), prompt_str.length() > 100 ? "..." 
: ""); + + mtmd_input_text inp_txt = { + prompt_str.c_str(), + /* add_special */ true, + /* parse_special */ true, + }; + mtmd::input_chunks chunks(mtmd_input_chunks_init()); + auto bitmaps_c_ptr = bitmaps.c_ptr(); + + printf("EMBEDDINGS: Calling mtmd_tokenize with %zu bitmaps\n", bitmaps_c_ptr.size()); + + int32_t tokenized = mtmd_tokenize(ctx_server.mctx, + chunks.ptr.get(), + &inp_txt, + bitmaps_c_ptr.data(), + bitmaps_c_ptr.size()); + + if (tokenized != 0) { + printf("EMBEDDINGS: mtmd_tokenize failed with error: %d\n", tokenized); + throw std::runtime_error("Failed to tokenize prompt"); + } + + printf("EMBEDDINGS: mtmd_tokenize succeeded\n"); + + server_tokens tmp(chunks, true); + printf("EMBEDDINGS: Created server_tokens\n"); + inputs.push_back(std::move(tmp)); + + } else { + // non-multimodal version + printf("EMBEDDINGS: Using non-multimodal tokenization\n"); + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + printf("EMBEDDINGS: Tokenized %zu prompts\n", tokenized_prompts.size()); + + for (auto & p : tokenized_prompts) { + printf("EMBEDDINGS: Prompt tokens: %zu\n", p.size()); + auto tmp = server_tokens(p, ctx_server.mctx != nullptr); + inputs.push_back(std::move(tmp)); + } + } + + printf("EMBEDDINGS: Total inputs created: %zu\n", inputs.size()); + + // Create embedding tasks + std::vector tasks; + std::unordered_set task_ids; + + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.prompt_tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( + ctx_server.ctx, + ctx_server.params_base, + data); + task.id_selected_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = oaicompat; + + tasks.push_back(std::move(task)); + } + + task_ids = server_task::get_list_id(tasks); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(std::move(tasks)); + + printf("EMBEDDINGS: Posted %zu tasks to queue\n", tasks.size()); + + // Wait for results + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + if (results.size() == 1) { + // single result + printf("EMBEDDINGS: Returning single result\n"); + res_ok(res, results[0]->to_json()); + } else { + // multiple results (multitask) + printf("EMBEDDINGS: Returning %zu results\n", results.size()); + json arr = json::array(); + for (auto & res : results) { + arr.push_back(res->to_json()); + } + res_ok(res, arr); + } + }, [&](const json & error_data) { + printf("EMBEDDINGS: Error occurred during processing\n"); + res_error(res, error_data); + }, [&req]() { + return !req.has_header("Connection") || req.get_header_value("Connection") != "keep-alive"; + }); + + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + printf("EMBEDDINGS: Processing completed\n"); + }; + const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { + std::vector files; + + // Parse the request body here to extract images const json body = json::parse(req.body); - - // for the shape of input/content, see tokenize_input_prompts() - json prompt; - if (body.count("input") != 0) { - prompt = body.at("input"); - } else if (body.contains("content")) { - oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible - prompt = body.at("content"); - } else { - res_error(res, format_error_response("\"input\" or 
\"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return; - } - - bool use_base64 = false; - if (body.count("encoding_format") != 0) { - const std::string& format = body.at("encoding_format"); - if (format == "base64") { - use_base64 = true; - } else if (format != "float") { - res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); - return; - } - } - - auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); - for (const auto & tokens : tokenized_prompts) { - // this check is necessary for models that do not add BOS token to the input - if (tokens.empty()) { - res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); - return; - } - } - - // create and queue the task - json responses = json::array(); - bool error = false; - std::unordered_set task_ids; - { - std::vector tasks; - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr); - - // OAI-compat - task.params.oaicompat = oaicompat; - - tasks.push_back(std::move(task)); + + // Handle simple image field for non-OAI endpoint + if (body.contains("image")) { + std::string image_data = body.at("image"); + if (string_starts_with(image_data, "data:image/")) { + auto parts = string_split(image_data, ','); + auto decoded_data = base64_decode(parts[1]); + files.push_back(decoded_data); } - - task_ids = server_task::get_list_id(tasks); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(std::move(tasks)); - } - - // get the result - ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { - for (auto & res : results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - }, [&](const json & error_data) { - res_error(res, error_data); - error = true; - }, req.is_connection_closed); - - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - - if (error) { - return; } - - // write JSON response - json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING - ? format_embeddings_response_oaicompat(body, responses, use_base64) - : json(responses); - res_ok(res, root); - }; - - const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { - handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE); + + handle_embeddings_impl(req, res, files, OAICOMPAT_TYPE_NONE); }; const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { - handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING); + std::vector files; // dummy files + handle_embeddings_impl(req, res, files, OAICOMPAT_TYPE_EMBEDDING); }; const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {