Skip to content

Commit 51f7d7f

Browse files
committed
chore: fix bugs, add readme example
1 parent 0dd0291 commit 51f7d7f

File tree

3 files changed

+99
-97
lines changed

3 files changed

+99
-97
lines changed

jina_embeddings/README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Inference example
2+
```bash
3+
python infer.py \
4+
--llama-bin /home/andrei/workspace/llama.cpp/build/bin/llama-server \
5+
--model /home/andrei/workspace/gguf/jev4-bf16.gguf \
6+
--mmproj /home/andrei/workspace/gguf/mmproj-jev4-bf16.gguf \
7+
--gpus 7 \
8+
--input /home/andrei/workspace/test_data.txt \
9+
--output /home/andrei/workspace/jev4_mmtd.json \
10+
--save-cosine-sim-path /home/andrei/workspace/jev4_mmtd.md \
11+
--query-prefix "Query: " \
12+
--document-prefix "Passage: " \
13+
--normalize-after-pooling
14+
```

jina_embeddings/infer.py

Lines changed: 55 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,66 @@
22
import os
33
import signal
44
import subprocess
5-
import time
65

76
import click # type: ignore
8-
import requests # type: ignore
7+
import numpy as np # type: ignore
98
from sklearn.metrics.pairwise import cosine_similarity # type: ignore
109

1110
from model import LlamaCppServerEmbeddingModel
1211

1312

13+
def clip_text(text: str, max_len: int = 10) -> str:
14+
"""Clip text to max_len characters, showing first part + '...' if needed"""
15+
if len(text) <= max_len:
16+
return text
17+
return text[:max_len-3] + "..."
18+
19+
20+
def save_cosine_similarity_matrix(raw_lines: list[str], embeddings: np.ndarray, save_path: str) -> None:
21+
"""Save cosine similarity matrix as markdown table"""
22+
# Extract display names from original texts
23+
display_names = []
24+
for text in raw_lines:
25+
if text.startswith('[QUERY] '):
26+
content = text[8:]
27+
display_names.append(f"Q:{clip_text(content)}")
28+
elif text.startswith('[DOCUMENT] '):
29+
content = text[11:]
30+
display_names.append(f"D:{clip_text(content)}")
31+
elif text.startswith('[IMAGE] '):
32+
image_path = text[8:]
33+
filename = os.path.basename(image_path)
34+
display_names.append(f"I:{clip_text(filename)}")
35+
else:
36+
display_names.append(clip_text(text))
37+
38+
# Compute cosine similarity matrix
39+
similarity_matrix = cosine_similarity(embeddings)
40+
41+
# Create markdown table
42+
with open(save_path, 'w', encoding='utf-8') as f:
43+
f.write("# Cosine Similarity Matrix\n\n")
44+
45+
# Write header row
46+
f.write("| Item |")
47+
for name in display_names:
48+
f.write(f" {name} |")
49+
f.write("\n")
50+
51+
# Write separator row
52+
f.write("|" + "---|" * (len(display_names) + 1) + "\n")
53+
54+
# Write data rows
55+
for i, row_name in enumerate(display_names):
56+
f.write(f"| {row_name} |")
57+
for j in range(len(display_names)):
58+
sim_score = similarity_matrix[i, j]
59+
f.write(f" {sim_score:.3f} |")
60+
f.write("\n")
61+
62+
print(f"Saved cosine similarity matrix to {save_path}")
63+
64+
1465
@click.command()
1566
@click.option('--llama-bin', default='./llama-server', help='Path to llama-server binary')
1667
@click.option('--model', required=True, help='Path to model .gguf file')
@@ -49,37 +100,6 @@ def main(
49100
proc = subprocess.Popen(cmd, env=env)
50101

51102
try:
52-
print("Waiting for server to start...")
53-
54-
# Health check - wait until server is ready
55-
max_wait_time = 300 # 5 minutes
56-
check_interval = 2 # 2 seconds
57-
start_time = time.time()
58-
59-
while True:
60-
try:
61-
# Test the actual embedding endpoint with a simple request
62-
test_payload = {"content": "test"}
63-
health_response = requests.post(f"http://{host}:{port}/embedding", json=test_payload, timeout=10)
64-
if health_response.status_code == 200:
65-
print("✅ Server is ready!")
66-
break
67-
elif health_response.status_code == 503:
68-
elapsed = time.time() - start_time
69-
print(f"⏳ Server still loading model... ({elapsed:.1f}s elapsed)")
70-
else:
71-
elapsed = time.time() - start_time
72-
print(f"⚠️ Unexpected server response: {health_response.status_code} ({elapsed:.1f}s elapsed)")
73-
except requests.exceptions.RequestException as e:
74-
elapsed = time.time() - start_time
75-
print(f"⏳ Waiting for server to start... ({elapsed:.1f}s elapsed)")
76-
77-
# Check if we've exceeded max wait time
78-
if time.time() - start_time > max_wait_time:
79-
raise TimeoutError(f"Server did not become ready within {max_wait_time} seconds")
80-
81-
time.sleep(check_interval)
82-
83103
with open(input_path, 'r', encoding='utf-8') as f:
84104
raw_lines = [line.strip() for line in f if line.strip()]
85105

@@ -93,6 +113,7 @@ def main(
93113
image_prefix=image_prefix
94114
)
95115

116+
model.wait_for_server()
96117
original_texts, embeddings = model.encode_from_lines(raw_lines)
97118

98119
output_data = [
@@ -107,53 +128,7 @@ def main(
107128

108129
# Save cosine similarity matrix if requested
109130
if save_cosine_sim_path:
110-
def clip_text(text, max_len=10):
111-
"""Clip text to max_len characters, showing first part + '...' if needed"""
112-
if len(text) <= max_len:
113-
return text
114-
return text[:max_len-3] + "..."
115-
116-
# Extract display names from original texts
117-
display_names = []
118-
for i, text in enumerate(raw_lines):
119-
if text.startswith('[QUERY] '):
120-
content = text[8:]
121-
display_names.append(f"Q:{clip_text(content)}")
122-
elif text.startswith('[DOCUMENT] '):
123-
content = text[11:]
124-
display_names.append(f"D:{clip_text(content)}")
125-
elif text.startswith('[IMAGE] '):
126-
image_path = text[8:]
127-
filename = os.path.basename(image_path)
128-
display_names.append(f"I:{clip_text(filename)}")
129-
else:
130-
display_names.append(clip_text(text))
131-
132-
# Compute cosine similarity matrix
133-
similarity_matrix = cosine_similarity(embeddings)
134-
135-
# Create markdown table
136-
with open(save_cosine_sim_path, 'w', encoding='utf-8') as f:
137-
f.write("# Cosine Similarity Matrix\n\n")
138-
139-
# Write header row
140-
f.write("| Item |")
141-
for name in display_names:
142-
f.write(f" {name} |")
143-
f.write("\n")
144-
145-
# Write separator row
146-
f.write("|" + "---|" * (len(display_names) + 1) + "\n")
147-
148-
# Write data rows
149-
for i, row_name in enumerate(display_names):
150-
f.write(f"| {row_name} |")
151-
for j in range(len(display_names)):
152-
sim_score = similarity_matrix[i, j]
153-
f.write(f" {sim_score:.3f} |")
154-
f.write("\n")
155-
156-
print(f"Saved cosine similarity matrix to {save_cosine_sim_path}")
131+
save_cosine_similarity_matrix(raw_lines, embeddings, save_cosine_sim_path)
157132

158133
finally:
159134
print("Shutting down server...")

jina_embeddings/model.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import base64
22
import os
3+
import time
34
from typing import List, Optional, Tuple
45

56
import numpy as np # type: ignore
@@ -27,6 +28,25 @@ def __init__(
2728
self.document_prefix = document_prefix
2829
self.image_prefix = image_prefix
2930

31+
def wait_for_server(self, max_wait_time: int = 300, check_interval: int = 2) -> None:
32+
"""Wait for the server to be ready"""
33+
print("Waiting for server to start...")
34+
test_payload = {"content": "test"}
35+
36+
start_time = time.time()
37+
while True:
38+
elapsed = time.time() - start_time
39+
if elapsed > max_wait_time:
40+
raise TimeoutError(f"Server did not become ready within {max_wait_time} seconds")
41+
try:
42+
r = requests.post(f"{self.server_url}/embedding", json=test_payload, timeout=10)
43+
assert r.status_code == 200, f"Server not ready: {r.status_code}"
44+
print("✅ Server is ready!")
45+
break
46+
except (requests.exceptions.RequestException, AssertionError):
47+
print(f"⏳ Waiting for server to start... ({elapsed:.1f}s elapsed)")
48+
time.sleep(check_interval)
49+
3050
def _parse_line(self, line: str) -> Tuple[str, EmbeddingRequestItem]:
3151
"""Parse input line and return (original_content, EmbeddingRequestItem)"""
3252
if line.startswith('[QUERY] '):
@@ -77,23 +97,25 @@ def encode(self, items: List[EmbeddingRequestItem]) -> np.ndarray:
7797
embeddings = []
7898

7999
for i, item in enumerate(items):
80-
payload = {"content": item["content"], "image": item["image"]}
100+
payload = {"content": item["content"]}
101+
if item["image"]:
102+
payload["image"] = item["image"]
103+
81104
is_image_request = item["image"] is not None
82105
response = requests.post(f"{self.server_url}/embedding", json=payload)
83106
assert response.status_code == 200, f"Server error: {response.text}"
84107
embedding_data = response.json()
108+
raw_embedding = embedding_data["embedding"]
85109

110+
# TODO: optional enable logging via argument
86111
print(f"\n==========================")
87112
print(f"🧠 Item {i + 1} embedding response")
88113
print(f"📦 Type: {type(embedding_data).__name__}")
89114
print(f"🔑 Keys: {list(embedding_data.keys())}")
90115
print(f"🔎 Preview: {repr(embedding_data)[:500]}")
91-
print(f"==========================")
92-
93-
raw_embedding = embedding_data["embedding"]
94-
95116
print(f"🔍 Raw embedding type: {type(raw_embedding)}")
96117
print(f"🔍 Raw embedding shape: {np.array(raw_embedding).shape}")
118+
print(f"==========================")
97119

98120
# Check if embeddings are already normalized
99121
embedding_array = np.array(raw_embedding)
@@ -104,27 +126,18 @@ def encode(self, items: List[EmbeddingRequestItem]) -> np.ndarray:
104126
# Handle image token extraction
105127
if is_image_request:
106128
start_idx = embedding_data["start_image_token_idx"]
107-
end_idx = embedding_data["end_image_token_idx"]
108-
109-
print(f"🖼️ Image token indices: start={start_idx}, end={end_idx}")
110-
111-
# Token-level embeddings - extract only image tokens
129+
end_idx = embedding_data["end_image_token_idx"]
112130
hidden_states = np.array(raw_embedding)
113131
image_embeddings = hidden_states[start_idx:end_idx+1] # +1 for inclusive end
114-
132+
pooled = image_embeddings.mean(axis=0)
133+
print(f"🖼️ Image token indices: start={start_idx}, end={end_idx}")
115134
print(f"🖼️ Extracted image embeddings shape: {image_embeddings.shape}")
116135
print(f"🖼️ Original total embeddings: {len(raw_embedding)}")
117136
print(f"🖼️ Image embeddings extracted: {len(image_embeddings)}")
118-
119-
# Pool only the image embeddings (always mean pool)
120-
pooled = image_embeddings.mean(axis=0)
121-
print(f"🖼️ Using mean pooling of image tokens")
122-
123137
else:
124138
# Regular text processing - always mean pool the tokens
125139
hidden_states = np.array(raw_embedding)
126140
pooled = hidden_states.mean(axis=0)
127-
print(f"📊 Applied mean pooling")
128141

129142
# Optional normalization
130143
if self.normalize_after_pooling:

0 commit comments

Comments
 (0)