1+ import json
2+ import os
3+ import signal
4+ import subprocess
5+ import time
6+
7+ import click # type: ignore
8+ import requests # type: ignore
9+ from sklearn .metrics .pairwise import cosine_similarity # type: ignore
10+
11+ from model import LlamaCppServerEmbeddingModel
12+
13+
def _client_host(host):
    """Return an address a local client can connect to for bind address *host*."""
    # '0.0.0.0' is a wildcard *bind* address; whether a client can connect to
    # it is platform-dependent, so talk to the loopback interface instead.
    return '127.0.0.1' if host == '0.0.0.0' else host


def _wait_for_server(proc, server_url, max_wait_time=300, check_interval=2):
    """Block until the embedding endpoint of the server answers with HTTP 200.

    Args:
        proc: The ``subprocess.Popen`` handle of the launched llama-server.
        server_url: Base URL (scheme, host and port) used for the probe request.
        max_wait_time: Give up after this many seconds.
        check_interval: Seconds to sleep between two probes.

    Raises:
        RuntimeError: The server process exited before it became ready.
        TimeoutError: The server did not become ready within ``max_wait_time``.
    """
    start_time = time.monotonic()

    while True:
        # A server that already died (bad model path, port in use, ...) will
        # never become ready; fail right away instead of polling until timeout.
        exit_code = proc.poll()
        if exit_code is not None:
            raise RuntimeError(
                f"llama-server exited with code {exit_code} before becoming ready"
            )

        try:
            # Probe the real embedding endpoint with a minimal request.
            response = requests.post(
                f"{server_url}/embedding", json={"content": "test"}, timeout=10
            )
        except requests.exceptions.RequestException:
            elapsed = time.monotonic() - start_time
            print(f"⏳ Waiting for server to start... ({elapsed:.1f}s elapsed)")
        else:
            elapsed = time.monotonic() - start_time
            if response.status_code == 200:
                print("✅ Server is ready!")
                return
            if response.status_code == 503:
                print(f"⏳ Server still loading model... ({elapsed:.1f}s elapsed)")
            else:
                print(f"⚠️ Unexpected server response: {response.status_code} ({elapsed:.1f}s elapsed)")

        if time.monotonic() - start_time > max_wait_time:
            raise TimeoutError(f"Server did not become ready within {max_wait_time} seconds")

        time.sleep(check_interval)


def _clip_text(text, max_len=10):
    """Clip text to max_len characters, ending in '...' when it was shortened."""
    if len(text) <= max_len:
        return text
    return text[:max_len - 3] + "..."


def _display_name(line):
    """Build the short row/column label for one tagged input line."""
    for tag, label in (('[QUERY] ', 'Q:'), ('[DOCUMENT] ', 'D:'), ('[IMAGE] ', 'I:')):
        if line.startswith(tag):
            content = line[len(tag):]
            if label == 'I:':
                # For images only the file name is shown, not the full path.
                content = os.path.basename(content)
            return f"{label}{_clip_text(content)}"
    return _clip_text(line)


def _write_similarity_markdown(path, display_names, similarity_matrix):
    """Write the square similarity matrix to *path* as a markdown table."""
    # A literal '|' inside a cell would be read as a column separator.
    cells = [name.replace('|', '\\|') for name in display_names]

    lines = ["# Cosine Similarity Matrix\n\n"]
    lines.append("| Item |" + "".join(f" {cell} |" for cell in cells) + "\n")
    lines.append("|" + "---|" * (len(cells) + 1) + "\n")
    for cell, row in zip(cells, similarity_matrix):
        lines.append(f"| {cell} |" + "".join(f" {score:.3f} |" for score in row) + "\n")

    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def _stop_server(proc, timeout=10):
    """Stop the server process: SIGINT first, then kill it after *timeout* seconds."""
    if proc.poll() is not None:
        return  # already exited, nothing to signal
    proc.send_signal(signal.SIGINT)
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        print("Server did not shut down in time; killing process.")
        proc.kill()
        proc.wait()  # reap the killed child so it does not linger as a zombie


@click.command()
@click.option('--llama-bin', default='./llama-server', help='Path to llama-server binary')
@click.option('--model', required=True, help='Path to model .gguf file')
@click.option('--mmproj', required=True, help='Path to mmproj .gguf file')
@click.option('--port', default=8080, help='Port for llama-server')
@click.option('--host', default='0.0.0.0', help='Host for llama-server')
@click.option('--ngl', default=999, help='Number of GPU layers')
@click.option('--gpus', default='0', help='CUDA_VISIBLE_DEVICES comma separated GPU ids (e.g. "0,1")')
@click.option('--input', 'input_path', required=True, help='Path to input txt file. Format: "[TYPE] content" where TYPE is QUERY, DOCUMENT, or IMAGE. For IMAGE, content should be the file path.')
@click.option('--output', 'output_path', required=True, help='Path to output JSON file for embeddings')
@click.option('--normalize-after-pooling', is_flag=True, default=False, help='Apply L2 normalization after pooling')
@click.option('--save-cosine-sim-path', help='Path to save cosine similarity matrix as markdown table')
@click.option('--query-prefix', default='Query: ', help='Prefix for [QUERY] lines')
@click.option('--document-prefix', default='Passage: ', help='Prefix for [DOCUMENT] lines')
@click.option('--image-prefix', default='Describe the image.<__image__>', help='Prefix for [IMAGE] lines')
def main(
    llama_bin, model, mmproj, port, host, ngl, gpus,
    input_path, output_path,
    normalize_after_pooling,
    save_cosine_sim_path, query_prefix, document_prefix, image_prefix
):
    """Embed the lines of an input file with a locally launched llama-server.

    \f
    Everything after the form feed above is hidden from ``--help``.

    Steps: read the non-empty lines of ``input_path``, start ``llama_bin``
    with the given model files, wait until its ``/embedding`` endpoint
    answers, encode the lines through ``LlamaCppServerEmbeddingModel``, write
    the embeddings to ``output_path`` as JSON and, if ``save_cosine_sim_path``
    is set, write the pairwise cosine similarities as a markdown table.
    The server process is always shut down on exit.

    Raises:
        RuntimeError: llama-server exited before it became ready.
        TimeoutError: llama-server was not ready within the wait limit.
        ValueError: The number of embeddings does not match the input lines
            when the similarity table is requested.
    """
    # Read the input first so that a bad --input path fails before the
    # (potentially slow) model load rather than after it.
    with open(input_path, 'r', encoding='utf-8') as f:
        raw_lines = [line.strip() for line in f if line.strip()]

    print(f"Loaded {len(raw_lines)} sentences from {input_path}")

    env = os.environ.copy()
    env['CUDA_VISIBLE_DEVICES'] = gpus

    cmd = [
        llama_bin,
        '-m', model,
        '--mmproj', mmproj,
        '--embedding',
        '--port', str(port),
        '-ngl', str(ngl),
        '--host', host,
        '--pooling', 'none',
    ]
    print(f"Starting llama-server with: {' '.join(cmd)}")
    proc = subprocess.Popen(cmd, env=env)

    server_url = f"http://{_client_host(host)}:{port}"

    try:
        print("Waiting for server to start...")
        _wait_for_server(proc, server_url)

        # Named `embedder` so the `model` path argument is not shadowed.
        embedder = LlamaCppServerEmbeddingModel(
            server_url=server_url,
            normalize_after_pooling=normalize_after_pooling,
            query_prefix=query_prefix,
            document_prefix=document_prefix,
            image_prefix=image_prefix,
        )

        original_texts, embeddings = embedder.encode_from_lines(raw_lines)

        output_data = [
            {"text": text, "embedding": embedding.tolist()}
            for text, embedding in zip(original_texts, embeddings)
        ]

        with open(output_path, 'w', encoding='utf-8') as f_out:
            json.dump(output_data, f_out, indent=2)

        print(f"Saved embeddings to {output_path}")

        # Save cosine similarity matrix if requested
        if save_cosine_sim_path:
            display_names = [_display_name(line) for line in raw_lines]
            similarity_matrix = cosine_similarity(embeddings)

            # Labels come from the input lines, values from the embeddings;
            # refuse to write a table whose labels would not line up.
            if similarity_matrix.shape[0] != len(display_names):
                raise ValueError(
                    f"Got {similarity_matrix.shape[0]} embeddings for "
                    f"{len(display_names)} input lines; cannot label similarity matrix"
                )

            _write_similarity_markdown(save_cosine_sim_path, display_names, similarity_matrix)

            print(f"Saved cosine similarity matrix to {save_cosine_sim_path}")

    finally:
        print("Shutting down server...")
        _stop_server(proc)
166+
167+
if __name__ == "__main__":
    # Script entry point: `main` is a click command, so calling it without
    # arguments lets click take the options from the command line.
    main()  # type: ignore
0 commit comments