77import traceback
88import threading
99from openai import OpenAI
10- import requests
1110
1211from utils .config import CFG
1312
@@ -117,6 +116,7 @@ class EmbeddingClient:
117116 """
118117 Embedding client with detailed logging, retry logic, and configurable timeouts.
119118 Provides better debugging for embedding API failures.
119+ Uses OpenAI SDK for proper API compatibility.
120120 """
121121 def __init__ (self ,
122122 api_url : Optional [str ] = None ,
@@ -131,30 +131,34 @@ def __init__(self,
131131 self .timeout = timeout
132132 self .max_retries = max_retries
133133 self .backoff = backoff
134- self . session = requests . Session ()
135- if self . api_key :
136- self . session . headers . update ({ "Authorization" : f"Bearer { self . api_key } " })
137- self .session . headers . update ({ "Content-Type" : "application/json" })
134+
135+ # Use OpenAI SDK client instead of raw requests
136+ # The SDK automatically handles the /embeddings path
137+ self .client = _client
138138
139- def _generate_curl_command (self , url : str , headers : Dict [ str , str ], payload : Dict [str , Any ]) -> str :
139+ def _generate_curl_command (self , payload : Dict [str , Any ]) -> str :
140140 """
141141 Generate a curl command for debugging purposes.
142142 Masks the API key for security.
143143 """
144+ # Construct the full embeddings URL
145+ base_url = self .api_url .rstrip ('/' )
146+ if not base_url .endswith ('/embeddings' ):
147+ url = f"{ base_url } /embeddings"
148+ else :
149+ url = base_url
150+
144151 # Start with basic curl command
145152 curl_parts = ["curl" , "-X" , "POST" , f"'{ url } '" ]
146153
147- # Add headers
154+ # Add standard headers
155+ headers = {
156+ "Content-Type" : "application/json" ,
157+ "Authorization" : f"Bearer <API_KEY_MASKED>"
158+ }
159+
148160 for key , value in headers .items ():
149- if key .lower () == "authorization" and value :
150- # Mask the API key for security
151- if value .startswith ("Bearer " ):
152- masked_value = f"Bearer <API_KEY_MASKED>"
153- else :
154- masked_value = "<API_KEY_MASKED>"
155- curl_parts .append (f"-H '{ key } : { masked_value } '" )
156- else :
157- curl_parts .append (f"-H '{ key } : { value } '" )
161+ curl_parts .append (f"-H '{ key } : { value } '" )
158162
159163 # Add data payload
160164 payload_json = json .dumps (payload )
@@ -225,7 +229,7 @@ def _log_request_end(self, request_id: str, elapsed: float, status: Optional[int
225229
226230 def embed_text (self , text : str , file_path : str = "<unknown>" , chunk_index : int = 0 ) -> List [float ]:
227231 """
228- Embed a single chunk of text. Returns the embedding vector.
232+ Embed a single chunk of text using OpenAI SDK . Returns the embedding vector.
229233 Raises EmbeddingError on failure.
230234 """
231235 request_id = str (uuid .uuid4 ())
@@ -243,75 +247,41 @@ def embed_text(self, text: str, file_path: str = "<unknown>", chunk_index: int =
243247 attempt += 1
244248 start = time .perf_counter ()
245249 try :
246- resp = self .session .post (
247- self .api_url ,
248- data = json .dumps (payload ),
249- timeout = self .timeout ,
250+ # Use OpenAI SDK for embeddings
251+ resp = self .client .embeddings .create (
252+ model = self .model ,
253+ input = text ,
254+ timeout = self .timeout
250255 )
251256 elapsed = time .perf_counter () - start
252257
253- # Try to parse JSON safely
254- try :
255- resp_json = resp .json ()
256- except Exception :
257- resp_json = None
258-
259- preview = ""
260- if resp_json is not None :
261- preview = json .dumps (resp_json )[:1000 ]
262- else :
263- preview = (resp .text or "" )[:1000 ]
264-
265- self ._log_request_end (request_id , elapsed , resp .status_code , preview )
266-
267- if resp .status_code >= 200 and resp .status_code < 300 :
268- # expected format: {"data": [{"embedding": [...]}], ...}
269- if not resp_json :
270- raise EmbeddingError (f"Empty JSON response (status={ resp .status_code } )" )
271- try :
272- # tolerant extraction
273- data = resp_json .get ("data" ) if isinstance (resp_json , dict ) else None
274- if data and isinstance (data , list ) and len (data ) > 0 :
275- emb = data [0 ].get ("embedding" )
276- if emb and isinstance (emb , list ):
277- _embedding_logger .info (
278- "Embedding succeeded" ,
279- extra = {"request_id" : request_id , "file" : file_path , "chunk_index" : chunk_index },
280- )
281- return emb
282- # Fallback: maybe top-level "embedding" key
283- if isinstance (resp_json , dict ) and "embedding" in resp_json :
284- emb = resp_json ["embedding" ]
285- if isinstance (emb , list ):
286- return emb
287- raise EmbeddingError (f"Unexpected embedding response shape: { resp_json } " )
288- except KeyError as e :
289- raise EmbeddingError (f"Missing keys in embedding response: { e } " )
258+ # Log successful response
259+ self ._log_request_end (request_id , elapsed , 200 , "Success" )
260+
261+ # Extract embedding from response
262+ # The SDK returns a response object with a data list
263+ if resp and hasattr (resp , 'data' ) and len (resp .data ) > 0 :
264+ embedding = resp .data [0 ].embedding
265+ if embedding and isinstance (embedding , list ):
266+ _embedding_logger .info (
267+ "Embedding succeeded" ,
268+ extra = {"request_id" : request_id , "file" : file_path , "chunk_index" : chunk_index },
269+ )
270+ return embedding
271+ else :
272+ raise EmbeddingError (f"Invalid embedding format in response" )
290273 else :
291- # Non-2xx
292- _embedding_logger .warning (
293- "Embedding API returned non-2xx" ,
294- extra = {
295- "request_id" : request_id ,
296- "status_code" : resp .status_code ,
297- "file" : file_path ,
298- "chunk_index" : chunk_index ,
299- "attempt" : attempt ,
300- "body_preview" : preview ,
301- },
302- )
303- # fall through to retry logic
304- err_msg = f"Status { resp .status_code } : { preview } "
305-
306- except requests .Timeout as e :
274+ raise EmbeddingError (f"Unexpected embedding response shape from SDK" )
275+
276+ except Exception as e :
307277 elapsed = time .perf_counter () - start
308- err_msg = f"Timeout after { elapsed :.2f} s: { e } "
278+ err_msg = f"Error after { elapsed :.2f} s: { e } "
309279
310- # Save to bash script in /tmp if DEBUG is enabled
280+ # Save debug information for timeout or API errors
311281 script_path = None
312282 if CFG .get ("debug" ):
313283 # Generate curl command for debugging
314- curl_command = self ._generate_curl_command (self . api_url , dict ( self . session . headers ), payload )
284+ curl_command = self ._generate_curl_command (payload )
315285 script_path = self ._save_curl_script (curl_command , request_id , file_path , chunk_index )
316286 if script_path :
317287 _embedding_logger .error (f"\n Debug script saved to: { script_path } " )
@@ -321,24 +291,16 @@ def embed_text(self, text: str, file_path: str = "<unknown>", chunk_index: int =
321291 _embedding_logger .error (curl_command )
322292
323293 _embedding_logger .error (
324- "Embedding API Timeout " ,
294+ "Embedding API Error " ,
325295 extra = {
326296 "request_id" : request_id ,
327297 "error" : str (e ),
328298 "elapsed_s" : elapsed ,
329- "curl_command" : curl_command ,
330- "debug_script" : script_path
299+ "attempt" : attempt ,
300+ "file" : file_path ,
301+ "chunk_index" : chunk_index ,
331302 }
332303 )
333-
334- except requests .RequestException as e :
335- elapsed = time .perf_counter () - start
336- err_msg = f"RequestException after { elapsed :.2f} s: { e } \n { traceback .format_exc ()} "
337- _embedding_logger .error ("Embedding request exception" , extra = {"request_id" : request_id , "error" : err_msg })
338- except Exception as e :
339- elapsed = time .perf_counter () - start
340- err_msg = f"Unexpected error after { elapsed :.2f} s: { e } \n { traceback .format_exc ()} "
341- _embedding_logger .exception ("Unexpected embedding exception" , extra = {"request_id" : request_id })
342304
343305 # Retry logic
344306 if attempt > self .max_retries :
0 commit comments