Skip to content

Commit 0dd0291

Browse files
committed
feat: add inference scripts
1 parent 10889de commit 0dd0291

File tree

3 files changed

+323
-0
lines changed

3 files changed

+323
-0
lines changed

jina_embeddings/README.md

Whitespace-only changes.

jina_embeddings/infer.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import json
2+
import os
3+
import signal
4+
import subprocess
5+
import time
6+
7+
import click # type: ignore
8+
import requests # type: ignore
9+
from sklearn.metrics.pairwise import cosine_similarity # type: ignore
10+
11+
from model import LlamaCppServerEmbeddingModel
12+
13+
14+
@click.command()
15+
@click.option('--llama-bin', default='./llama-server', help='Path to llama-server binary')
16+
@click.option('--model', required=True, help='Path to model .gguf file')
17+
@click.option('--mmproj', required=True, help='Path to mmproj .gguf file')
18+
@click.option('--port', default=8080, help='Port for llama-server')
19+
@click.option('--host', default='0.0.0.0', help='Host for llama-server')
20+
@click.option('--ngl', default=999, help='Number of GPU layers')
21+
@click.option('--gpus', default='0', help='CUDA_VISIBLE_DEVICES comma separated GPU ids (e.g. "0,1")')
22+
@click.option('--input', 'input_path', required=True, help='Path to input txt file. Format: "[TYPE] content" where TYPE is QUERY, DOCUMENT, or IMAGE. For IMAGE, content should be the file path.')
23+
@click.option('--output', 'output_path', required=True, help='Path to output JSON file for embeddings')
24+
@click.option('--normalize-after-pooling', is_flag=True, default=False, help='Apply L2 normalization after pooling')
25+
@click.option('--save-cosine-sim-path', help='Path to save cosine similarity matrix as markdown table')
26+
@click.option('--query-prefix', default='Query: ', help='Prefix for [QUERY] lines')
27+
@click.option('--document-prefix', default='Passage: ', help='Prefix for [DOCUMENT] lines')
28+
@click.option('--image-prefix', default='Describe the image.<__image__>', help='Prefix for [IMAGE] lines')
29+
def main(
30+
llama_bin, model, mmproj, port, host, ngl, gpus,
31+
input_path, output_path,
32+
normalize_after_pooling,
33+
save_cosine_sim_path, query_prefix, document_prefix, image_prefix
34+
):
35+
env = os.environ.copy()
36+
env['CUDA_VISIBLE_DEVICES'] = gpus
37+
38+
cmd = [
39+
llama_bin,
40+
'-m', model,
41+
'--mmproj', mmproj,
42+
'--embedding',
43+
'--port', str(port),
44+
'-ngl', str(ngl),
45+
'--host', host,
46+
'--pooling', 'none'
47+
]
48+
print(f"Starting llama-server with: {' '.join(cmd)}")
49+
proc = subprocess.Popen(cmd, env=env)
50+
51+
try:
52+
print("Waiting for server to start...")
53+
54+
# Health check - wait until server is ready
55+
max_wait_time = 300 # 5 minutes
56+
check_interval = 2 # 2 seconds
57+
start_time = time.time()
58+
59+
while True:
60+
try:
61+
# Test the actual embedding endpoint with a simple request
62+
test_payload = {"content": "test"}
63+
health_response = requests.post(f"http://{host}:{port}/embedding", json=test_payload, timeout=10)
64+
if health_response.status_code == 200:
65+
print("✅ Server is ready!")
66+
break
67+
elif health_response.status_code == 503:
68+
elapsed = time.time() - start_time
69+
print(f"⏳ Server still loading model... ({elapsed:.1f}s elapsed)")
70+
else:
71+
elapsed = time.time() - start_time
72+
print(f"⚠️ Unexpected server response: {health_response.status_code} ({elapsed:.1f}s elapsed)")
73+
except requests.exceptions.RequestException as e:
74+
elapsed = time.time() - start_time
75+
print(f"⏳ Waiting for server to start... ({elapsed:.1f}s elapsed)")
76+
77+
# Check if we've exceeded max wait time
78+
if time.time() - start_time > max_wait_time:
79+
raise TimeoutError(f"Server did not become ready within {max_wait_time} seconds")
80+
81+
time.sleep(check_interval)
82+
83+
with open(input_path, 'r', encoding='utf-8') as f:
84+
raw_lines = [line.strip() for line in f if line.strip()]
85+
86+
print(f"Loaded {len(raw_lines)} sentences from {input_path}")
87+
88+
model = LlamaCppServerEmbeddingModel(
89+
server_url=f"http://{host}:{port}",
90+
normalize_after_pooling=normalize_after_pooling,
91+
query_prefix=query_prefix,
92+
document_prefix=document_prefix,
93+
image_prefix=image_prefix
94+
)
95+
96+
original_texts, embeddings = model.encode_from_lines(raw_lines)
97+
98+
output_data = [
99+
{"text": text, "embedding": embedding.tolist()}
100+
for text, embedding in zip(original_texts, embeddings)
101+
]
102+
103+
with open(output_path, 'w', encoding='utf-8') as f_out:
104+
json.dump(output_data, f_out, indent=2)
105+
106+
print(f"Saved embeddings to {output_path}")
107+
108+
# Save cosine similarity matrix if requested
109+
if save_cosine_sim_path:
110+
def clip_text(text, max_len=10):
111+
"""Clip text to max_len characters, showing first part + '...' if needed"""
112+
if len(text) <= max_len:
113+
return text
114+
return text[:max_len-3] + "..."
115+
116+
# Extract display names from original texts
117+
display_names = []
118+
for i, text in enumerate(raw_lines):
119+
if text.startswith('[QUERY] '):
120+
content = text[8:]
121+
display_names.append(f"Q:{clip_text(content)}")
122+
elif text.startswith('[DOCUMENT] '):
123+
content = text[11:]
124+
display_names.append(f"D:{clip_text(content)}")
125+
elif text.startswith('[IMAGE] '):
126+
image_path = text[8:]
127+
filename = os.path.basename(image_path)
128+
display_names.append(f"I:{clip_text(filename)}")
129+
else:
130+
display_names.append(clip_text(text))
131+
132+
# Compute cosine similarity matrix
133+
similarity_matrix = cosine_similarity(embeddings)
134+
135+
# Create markdown table
136+
with open(save_cosine_sim_path, 'w', encoding='utf-8') as f:
137+
f.write("# Cosine Similarity Matrix\n\n")
138+
139+
# Write header row
140+
f.write("| Item |")
141+
for name in display_names:
142+
f.write(f" {name} |")
143+
f.write("\n")
144+
145+
# Write separator row
146+
f.write("|" + "---|" * (len(display_names) + 1) + "\n")
147+
148+
# Write data rows
149+
for i, row_name in enumerate(display_names):
150+
f.write(f"| {row_name} |")
151+
for j in range(len(display_names)):
152+
sim_score = similarity_matrix[i, j]
153+
f.write(f" {sim_score:.3f} |")
154+
f.write("\n")
155+
156+
print(f"Saved cosine similarity matrix to {save_cosine_sim_path}")
157+
158+
finally:
159+
print("Shutting down server...")
160+
proc.send_signal(signal.SIGINT)
161+
try:
162+
proc.wait(timeout=10)
163+
except subprocess.TimeoutExpired:
164+
print("Server did not shut down in time; killing process.")
165+
proc.kill()
166+
167+
168+
if __name__ == '__main__':
169+
main() # type: ignore

jina_embeddings/model.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import base64
2+
import os
3+
from typing import List, Optional, Tuple
4+
5+
import numpy as np # type: ignore
6+
import requests # type: ignore
7+
from typing_extensions import TypedDict # type: ignore
8+
9+
10+
class EmbeddingRequestItem(TypedDict):
11+
content: str
12+
image: Optional[str]
13+
14+
15+
class LlamaCppServerEmbeddingModel:
16+
def __init__(
17+
self,
18+
server_url: str = "http://localhost:8080",
19+
normalize_after_pooling: bool = False,
20+
query_prefix: str = "Query: ",
21+
document_prefix: str = "Passage: ",
22+
image_prefix: str = "Describe the image.<__image__>"
23+
) -> None:
24+
self.server_url = server_url
25+
self.normalize_after_pooling = normalize_after_pooling
26+
self.query_prefix = query_prefix
27+
self.document_prefix = document_prefix
28+
self.image_prefix = image_prefix
29+
30+
def _parse_line(self, line: str) -> Tuple[str, EmbeddingRequestItem]:
31+
"""Parse input line and return (original_content, EmbeddingRequestItem)"""
32+
if line.startswith('[QUERY] '):
33+
content = line[8:] # Remove '[QUERY] '
34+
item: EmbeddingRequestItem = { "content": self.query_prefix + content, "image": None }
35+
return content, item
36+
elif line.startswith('[DOCUMENT] '):
37+
content = line[11:] # Remove '[DOCUMENT] '
38+
item: EmbeddingRequestItem = { "content": self.document_prefix + content, "image": None }
39+
return content, item
40+
elif line.startswith('[IMAGE] '):
41+
image_path = line[8:] # Remove '[IMAGE] '
42+
data_url, success = self._process_image(image_path)
43+
assert success, f"Failed to process image: {image_path}"
44+
item: EmbeddingRequestItem = { "content": self.image_prefix, "image": data_url }
45+
return image_path, item
46+
else:
47+
raise ValueError(f"Invalid line format: {line}. Expected '[QUERY] ', '[DOCUMENT] ', or '[IMAGE] ' prefix.")
48+
49+
def _process_image(self, image_path: str) -> Tuple[Optional[str], bool]:
50+
"""Process image file and return (data_url, success)"""
51+
try:
52+
with open(image_path, 'rb') as img_file:
53+
image_data = base64.b64encode(img_file.read()).decode('utf-8')
54+
55+
# Detect image format from extension
56+
ext = os.path.splitext(image_path)[1].lower()
57+
if ext in ['.jpg', '.jpeg']:
58+
mime_type = 'image/jpeg'
59+
elif ext == '.png':
60+
mime_type = 'image/png'
61+
elif ext == '.webp':
62+
mime_type = 'image/webp'
63+
else:
64+
mime_type = 'image/jpeg' # default
65+
66+
data_url = f"data:{mime_type};base64,{image_data}"
67+
return data_url, True
68+
69+
except FileNotFoundError:
70+
print(f"❌ Image not found: {image_path}, processing as text only")
71+
return None, False
72+
73+
def encode(self, items: List[EmbeddingRequestItem]) -> np.ndarray:
74+
"""
75+
Encode items. Each item should be an EmbeddingRequestItem.
76+
"""
77+
embeddings = []
78+
79+
for i, item in enumerate(items):
80+
payload = {"content": item["content"], "image": item["image"]}
81+
is_image_request = item["image"] is not None
82+
response = requests.post(f"{self.server_url}/embedding", json=payload)
83+
assert response.status_code == 200, f"Server error: {response.text}"
84+
embedding_data = response.json()
85+
86+
print(f"\n==========================")
87+
print(f"🧠 Item {i + 1} embedding response")
88+
print(f"📦 Type: {type(embedding_data).__name__}")
89+
print(f"🔑 Keys: {list(embedding_data.keys())}")
90+
print(f"🔎 Preview: {repr(embedding_data)[:500]}")
91+
print(f"==========================")
92+
93+
raw_embedding = embedding_data["embedding"]
94+
95+
print(f"🔍 Raw embedding type: {type(raw_embedding)}")
96+
print(f"🔍 Raw embedding shape: {np.array(raw_embedding).shape}")
97+
98+
# Check if embeddings are already normalized
99+
embedding_array = np.array(raw_embedding)
100+
norms = np.linalg.norm(embedding_array, axis=1)
101+
if np.allclose(norms, 1.0, atol=1e-6):
102+
print(f"⚠️ WARNING: Raw embeddings appear to be already normalized!")
103+
104+
# Handle image token extraction
105+
if is_image_request:
106+
start_idx = embedding_data["start_image_token_idx"]
107+
end_idx = embedding_data["end_image_token_idx"]
108+
109+
print(f"🖼️ Image token indices: start={start_idx}, end={end_idx}")
110+
111+
# Token-level embeddings - extract only image tokens
112+
hidden_states = np.array(raw_embedding)
113+
image_embeddings = hidden_states[start_idx:end_idx+1] # +1 for inclusive end
114+
115+
print(f"🖼️ Extracted image embeddings shape: {image_embeddings.shape}")
116+
print(f"🖼️ Original total embeddings: {len(raw_embedding)}")
117+
print(f"🖼️ Image embeddings extracted: {len(image_embeddings)}")
118+
119+
# Pool only the image embeddings (always mean pool)
120+
pooled = image_embeddings.mean(axis=0)
121+
print(f"🖼️ Using mean pooling of image tokens")
122+
123+
else:
124+
# Regular text processing - always mean pool the tokens
125+
hidden_states = np.array(raw_embedding)
126+
pooled = hidden_states.mean(axis=0)
127+
print(f"📊 Applied mean pooling")
128+
129+
# Optional normalization
130+
if self.normalize_after_pooling:
131+
norm = np.linalg.norm(pooled)
132+
if norm > 0:
133+
pooled = pooled / norm
134+
print(f"🔄 Applied L2 normalization")
135+
136+
embeddings.append(pooled)
137+
138+
return np.array(embeddings)
139+
140+
def encode_from_lines(self, raw_lines: List[str]) -> Tuple[List[str], np.ndarray]:
141+
"""
142+
Process raw lines with type prefixes and return embeddings along with original content
143+
Returns: (original_texts, embeddings)
144+
"""
145+
original_texts = []
146+
items = []
147+
148+
for line in raw_lines:
149+
original, item = self._parse_line(line.strip())
150+
original_texts.append(original)
151+
items.append(item)
152+
153+
embeddings = self.encode(items)
154+
return original_texts, embeddings

0 commit comments

Comments
 (0)