Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions jina_embeddings/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Inference example

Starts a local `llama-server` in embedding mode, embeds every `[QUERY]`/`[DOCUMENT]`/`[IMAGE]` line of the input file, and writes the embeddings (and optionally a cosine-similarity table) to disk:
```bash
python infer.py \
--llama-bin /home/andrei/workspace/llama.cpp/build/bin/llama-server \
--model /home/andrei/workspace/gguf/jev4-bf16.gguf \
--mmproj /home/andrei/workspace/gguf/mmproj-jev4-bf16.gguf \
--gpus 7 \
--input /home/andrei/workspace/test_data.txt \
--output /home/andrei/workspace/jev4_mmtd.json \
--save-cosine-sim-path /home/andrei/workspace/jev4_mmtd.md \
--query-prefix "Query: " \
--document-prefix "Passage: " \
--normalize-after-pooling
```
144 changes: 144 additions & 0 deletions jina_embeddings/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import json
import os
import signal
import subprocess

import click # type: ignore
import numpy as np # type: ignore
from sklearn.metrics.pairwise import cosine_similarity # type: ignore

from model import LlamaCppServerEmbeddingModel


def clip_text(text: str, max_len: int = 10) -> str:
    """Clip text to max_len characters, showing first part + '...' if needed"""
    if len(text) > max_len:
        # Reserve three characters for the ellipsis so the result
        # is exactly max_len characters long.
        return text[:max_len - 3] + "..."
    return text


def save_cosine_similarity_matrix(raw_lines: list[str], embeddings: np.ndarray, save_path: str) -> None:
    """Save the pairwise cosine-similarity matrix of *embeddings* as a markdown table.

    Args:
        raw_lines: The original typed input lines ('[QUERY] ...', '[DOCUMENT] ...',
            '[IMAGE] <path>', or plain text); used only to build row/column labels.
        embeddings: 2D array of shape (n_items, dim), one row per input line,
            in the same order as ``raw_lines``.
        save_path: Destination path for the markdown file.
    """
    # Build short display labels from the typed input lines.
    display_names = []
    for text in raw_lines:
        if text.startswith('[QUERY] '):
            display_names.append(f"Q:{clip_text(text[8:])}")
        elif text.startswith('[DOCUMENT] '):
            display_names.append(f"D:{clip_text(text[11:])}")
        elif text.startswith('[IMAGE] '):
            # Label images by their file name, not the full path.
            display_names.append(f"I:{clip_text(os.path.basename(text[8:]))}")
        else:
            display_names.append(clip_text(text))

    # Cosine similarity in plain numpy (no need for sklearn here): normalize
    # the rows, then take the Gram matrix.  Zero vectors keep a divisor of 1
    # so their similarity is 0 — same convention as
    # sklearn.metrics.pairwise.cosine_similarity.
    matrix = np.asarray(embeddings, dtype=float)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    norms[norms == 0.0] = 1.0
    unit = matrix / norms
    similarity_matrix = unit @ unit.T

    with open(save_path, 'w', encoding='utf-8') as f:
        f.write("# Cosine Similarity Matrix\n\n")

        # Header row.
        f.write("| Item |")
        for name in display_names:
            f.write(f" {name} |")
        f.write("\n")

        # Markdown separator row ('Item' column plus one per embedding).
        f.write("|" + "---|" * (len(display_names) + 1) + "\n")

        # One data row per item, scores rounded to three decimals.
        for i, row_name in enumerate(display_names):
            f.write(f"| {row_name} |")
            for j in range(len(display_names)):
                f.write(f" {similarity_matrix[i, j]:.3f} |")
            f.write("\n")

    print(f"Saved cosine similarity matrix to {save_path}")


@click.command()
@click.option('--llama-bin', default='./llama-server', help='Path to llama-server binary')
@click.option('--model', required=True, help='Path to model .gguf file')
@click.option('--mmproj', required=True, help='Path to mmproj .gguf file')
@click.option('--port', default=8080, help='Port for llama-server')
@click.option('--host', default='0.0.0.0', help='Host for llama-server')
@click.option('--ngl', default=999, help='Number of GPU layers')
@click.option('--gpus', default='0', help='CUDA_VISIBLE_DEVICES comma separated GPU ids (e.g. "0,1")')
@click.option('--input', 'input_path', required=True, help='Path to input txt file. Format: "[TYPE] content" where TYPE is QUERY, DOCUMENT, or IMAGE. For IMAGE, content should be the file path.')
@click.option('--output', 'output_path', required=True, help='Path to output JSON file for embeddings')
@click.option('--normalize-after-pooling', is_flag=True, default=False, help='Apply L2 normalization after pooling')
@click.option('--save-cosine-sim-path', help='Path to save cosine similarity matrix as markdown table')
@click.option('--query-prefix', default='Query: ', help='Prefix for [QUERY] lines')
@click.option('--document-prefix', default='Passage: ', help='Prefix for [DOCUMENT] lines')
@click.option('--image-prefix', default='Describe the image.<__image__>', help='Prefix for [IMAGE] lines')
def main(
    llama_bin, model, mmproj, port, host, ngl, gpus,
    input_path, output_path,
    normalize_after_pooling,
    save_cosine_sim_path, query_prefix, document_prefix, image_prefix
):
    """Start a llama-server embedding endpoint, embed every line of the input
    file, dump the embeddings to a JSON file, and optionally save a cosine
    similarity matrix as a markdown table."""
    env = os.environ.copy()
    env['CUDA_VISIBLE_DEVICES'] = gpus

    # '--pooling none' makes the server return per-token hidden states;
    # pooling is done client-side in LlamaCppServerEmbeddingModel.
    cmd = [
        llama_bin,
        '-m', model,
        '--mmproj', mmproj,
        '--embedding',
        '--port', str(port),
        '-ngl', str(ngl),
        '--host', host,
        '--pooling', 'none'
    ]
    print(f"Starting llama-server with: {' '.join(cmd)}")
    proc = subprocess.Popen(cmd, env=env)

    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            raw_lines = [line.strip() for line in f if line.strip()]

        print(f"Loaded {len(raw_lines)} sentences from {input_path}")

        # NOTE: deliberately NOT named `model` — that name already holds the
        # .gguf path from the --model option and reusing it was confusing.
        embedding_model = LlamaCppServerEmbeddingModel(
            server_url=f"http://{host}:{port}",
            normalize_after_pooling=normalize_after_pooling,
            query_prefix=query_prefix,
            document_prefix=document_prefix,
            image_prefix=image_prefix
        )

        embedding_model.wait_for_server()
        original_texts, embeddings = embedding_model.encode_from_lines(raw_lines)

        output_data = [
            {"text": text, "embedding": embedding.tolist()}
            for text, embedding in zip(original_texts, embeddings)
        ]

        with open(output_path, 'w', encoding='utf-8') as f_out:
            json.dump(output_data, f_out, indent=2)

        print(f"Saved embeddings to {output_path}")

        # Save cosine similarity matrix if requested
        if save_cosine_sim_path:
            save_cosine_similarity_matrix(raw_lines, embeddings, save_cosine_sim_path)

    finally:
        # Always stop the server, even on error: SIGINT first for a clean
        # shutdown, hard kill if it does not exit within 10 seconds.
        print("Shutting down server...")
        proc.send_signal(signal.SIGINT)
        try:
            proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            print("Server did not shut down in time; killing process.")
            proc.kill()


if __name__ == '__main__':
    main()  # type: ignore
167 changes: 167 additions & 0 deletions jina_embeddings/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import base64
import os
import time
from typing import List, Optional, Tuple

import numpy as np # type: ignore
import requests # type: ignore
from typing_extensions import TypedDict # type: ignore


class EmbeddingRequestItem(TypedDict):
    """Single request payload for the llama-server /embedding endpoint."""
    content: str  # text sent to the server (prefix already applied)
    image: Optional[str]  # base64 data URL for image inputs, else None


class LlamaCppServerEmbeddingModel:
    """Client for a llama-server instance started with `--embedding --pooling none`.

    The server returns per-token hidden states; this class pools them
    client-side (mean over all tokens for text, mean over the image-token
    span for images) and optionally L2-normalizes the pooled vector.
    """

    def __init__(
        self,
        server_url: str = "http://localhost:8080",
        normalize_after_pooling: bool = False,
        query_prefix: str = "Query: ",
        document_prefix: str = "Passage: ",
        image_prefix: str = "Describe the image.<__image__>"
    ) -> None:
        self.server_url = server_url
        self.normalize_after_pooling = normalize_after_pooling
        self.query_prefix = query_prefix
        self.document_prefix = document_prefix
        self.image_prefix = image_prefix

    def wait_for_server(self, max_wait_time: int = 300, check_interval: int = 2) -> None:
        """Poll the /embedding endpoint until the server answers with HTTP 200.

        Raises:
            TimeoutError: if the server is not ready within max_wait_time seconds.
        """
        print("Waiting for server to start...")
        test_payload = {"content": "test"}

        start_time = time.time()
        while True:
            elapsed = time.time() - start_time
            if elapsed > max_wait_time:
                raise TimeoutError(f"Server did not become ready within {max_wait_time} seconds")
            try:
                r = requests.post(f"{self.server_url}/embedding", json=test_payload, timeout=10)
                # Explicit status check instead of `assert`: asserts are
                # stripped under `python -O`, which would report a broken
                # server as ready.
                if r.status_code == 200:
                    print("✅ Server is ready!")
                    break
            except requests.exceptions.RequestException:
                pass
            print(f"⏳ Waiting for server to start... ({elapsed:.1f}s elapsed)")
            time.sleep(check_interval)

    def _parse_line(self, line: str) -> "Tuple[str, EmbeddingRequestItem]":
        """Parse an input line and return (original_content, EmbeddingRequestItem).

        Supported prefixes: '[QUERY] ', '[DOCUMENT] ', '[IMAGE] ' (file path).

        Raises:
            FileNotFoundError: if an '[IMAGE] ' line points to an unreadable file.
            ValueError: if the line has none of the supported prefixes.
        """
        if line.startswith('[QUERY] '):
            content = line[len('[QUERY] '):]
            return content, {"content": self.query_prefix + content, "image": None}
        if line.startswith('[DOCUMENT] '):
            content = line[len('[DOCUMENT] '):]
            return content, {"content": self.document_prefix + content, "image": None}
        if line.startswith('[IMAGE] '):
            image_path = line[len('[IMAGE] '):]
            data_url, success = self._process_image(image_path)
            # Explicit raise instead of `assert` so the check survives `python -O`.
            if not success:
                raise FileNotFoundError(f"Failed to process image: {image_path}")
            return image_path, {"content": self.image_prefix, "image": data_url}
        raise ValueError(f"Invalid line format: {line}. Expected '[QUERY] ', '[DOCUMENT] ', or '[IMAGE] ' prefix.")

    def _process_image(self, image_path: str) -> Tuple[Optional[str], bool]:
        """Read an image file and return (base64 data URL, success flag).

        Returns (None, False) when the file does not exist; other I/O errors
        propagate.
        """
        try:
            with open(image_path, 'rb') as img_file:
                image_data = base64.b64encode(img_file.read()).decode('utf-8')

            # Guess the MIME type from the file extension; default to JPEG.
            ext = os.path.splitext(image_path)[1].lower()
            mime_type = {
                '.jpg': 'image/jpeg',
                '.jpeg': 'image/jpeg',
                '.png': 'image/png',
                '.webp': 'image/webp',
            }.get(ext, 'image/jpeg')

            return f"data:{mime_type};base64,{image_data}", True

        except FileNotFoundError:
            print(f"❌ Image not found: {image_path}, processing as text only")
            return None, False

    def encode(self, items: "List[EmbeddingRequestItem]") -> np.ndarray:
        """
        Encode items. Each item should be an EmbeddingRequestItem.

        Returns a 2D array with one pooled embedding per item, in input order.

        Raises:
            RuntimeError: if the server answers with a non-200 status.
        """
        embeddings = []

        for i, item in enumerate(items):
            payload = {"content": item["content"]}
            if item["image"]:
                payload["image"] = item["image"]

            is_image_request = item["image"] is not None
            response = requests.post(f"{self.server_url}/embedding", json=payload)
            # Explicit raise instead of `assert` (stripped under `python -O`).
            if response.status_code != 200:
                raise RuntimeError(f"Server error: {response.text}")
            embedding_data = response.json()
            raw_embedding = embedding_data["embedding"]
            # Convert once; with '--pooling none' this is the per-token
            # hidden-state matrix — presumably (n_tokens, dim); the axis=1
            # norm below relies on it being 2D.
            hidden_states = np.array(raw_embedding)

            # TODO: optional enable logging via argument
            print(f"\n==========================")
            print(f"🧠 Item {i + 1} embedding response")
            print(f"📦 Type: {type(embedding_data).__name__}")
            print(f"🔑 Keys: {list(embedding_data.keys())}")
            print(f"🔎 Preview: {repr(embedding_data)[:500]}")
            print(f"🔍 Raw embedding type: {type(raw_embedding)}")
            print(f"🔍 Raw embedding shape: {hidden_states.shape}")
            print(f"==========================")

            # Warn if the per-token states already have unit norm — that would
            # suggest the server normalized them, which we don't expect here.
            norms = np.linalg.norm(hidden_states, axis=1)
            if np.allclose(norms, 1.0, atol=1e-6):
                print(f"⚠️ WARNING: Raw embeddings appear to be already normalized!")

            if is_image_request:
                # Mean-pool only the span of image tokens reported by the server.
                start_idx = embedding_data["start_image_token_idx"]
                end_idx = embedding_data["end_image_token_idx"]
                image_embeddings = hidden_states[start_idx:end_idx + 1]  # +1 for inclusive end
                pooled = image_embeddings.mean(axis=0)
                print(f"🖼️ Image token indices: start={start_idx}, end={end_idx}")
                print(f"🖼️ Extracted image embeddings shape: {image_embeddings.shape}")
                print(f"🖼️ Original total embeddings: {len(raw_embedding)}")
                print(f"🖼️ Image embeddings extracted: {len(image_embeddings)}")
            else:
                # Regular text processing - always mean pool the tokens
                pooled = hidden_states.mean(axis=0)

            # Optional L2 normalization of the pooled vector.
            if self.normalize_after_pooling:
                norm = np.linalg.norm(pooled)
                if norm > 0:
                    pooled = pooled / norm
                    print(f"🔄 Applied L2 normalization")

            embeddings.append(pooled)

        return np.array(embeddings)

    def encode_from_lines(self, raw_lines: List[str]) -> Tuple[List[str], np.ndarray]:
        """
        Process raw lines with type prefixes and return embeddings along with original content
        Returns: (original_texts, embeddings)
        """
        original_texts = []
        items = []

        for line in raw_lines:
            original, item = self._parse_line(line.strip())
            original_texts.append(original)
            items.append(item)

        embeddings = self.encode(items)
        return original_texts, embeddings
Loading