@@ -259,6 +259,17 @@ def get_api_provider_stream_iter(
             api_key=model_api_dict["api_key"],
             extra_body=extra_body,
         )
+    elif model_api_dict["api_type"] == "critique-labs-ai":
+        prompt = conv.to_openai_api_messages()
+        stream_iter = critique_api_stream_iter(
+            model_api_dict["model_name"],
+            prompt,
+            temperature,
+            top_p,
+            max_new_tokens,
+            api_key=model_api_dict.get("api_key"),
+            api_base=model_api_dict.get("api_base"),
+        )
     else:
         raise NotImplementedError()
 
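Review note on the hunk above: the new dispatch branch reads four keys from model_api_dict. A hypothetical endpoint entry illustrating them, not part of the patch (the "model_name" value is illustrative):

    model_api_dict = {
        "api_type": "critique-labs-ai",
        "model_name": "critique-search",  # illustrative name, not from the patch
        "api_key": None,   # optional; critique_api_stream_iter falls back to the CRITIQUE_API_KEY env var
        "api_base": None,  # optional; the iterator defaults to wss://api.critique-labs.ai/v1/ws/search
    }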
@@ -1345,3 +1356,163 @@ def metagen_api_stream_iter(
13451356 "text" : f"**API REQUEST ERROR** Reason: Unknown." ,
13461357 "error_code" : 1 ,
13471358 }
1359+
1360+
1361+ def critique_api_stream_iter (
1362+ model_name ,
1363+ messages ,
1364+ temperature ,
1365+ top_p ,
1366+ max_new_tokens ,
1367+ api_key = None ,
1368+ api_base = None ,
1369+ ):
1370+ import websockets
1371+ import threading
1372+ import queue
1373+ import json
1374+ import time
1375+
1376+ api_key = api_key or os .environ .get ("CRITIQUE_API_KEY" )
1377+ if not api_key :
1378+ yield {
1379+ "text" : "**API REQUEST ERROR** Reason: CRITIQUE_API_KEY not found in environment variables." ,
1380+ "error_code" : 1 ,
1381+ }
1382+ return
1383+
1384+ # Combine all messages into a single prompt
1385+ prompt = ""
1386+ for message in messages :
1387+ if isinstance (message ["content" ], str ):
1388+ role_prefix = (
1389+ f"{ message ['role' ].capitalize ()} : "
1390+ if message ["role" ] != "system"
1391+ else ""
1392+ )
1393+ prompt += f"{ role_prefix } { message ['content' ]} \n "
1394+ else : # Handle content that might be a list (for multimodal)
1395+ for content_item in message ["content" ]:
1396+ if content_item .get ("type" ) == "text" :
1397+ role_prefix = (
1398+ f"{ message ['role' ].capitalize ()} : "
1399+ if message ["role" ] != "system"
1400+ else ""
1401+ )
1402+ prompt += f"{ role_prefix } { content_item ['text' ]} \n "
+    prompt += "\nDO NOT RESPOND IN MARKDOWN or provide any citations"
+
+    # Log request parameters
+    gen_params = {
+        "model": model_name,
+        "prompt": prompt,
+        "temperature": temperature,
+        "top_p": top_p,
+        "max_new_tokens": max_new_tokens,
+    }
+    logger.info(f"==== request ====\n{gen_params}")
+
+    # Create a queue for communication between threads
+    response_queue = queue.Queue()
+    stop_event = threading.Event()
+    connection_closed = threading.Event()
+
+    # Thread function to handle WebSocket communication
+    def websocket_thread():
+        import asyncio
+
+        async def connect_and_stream():
+            uri = api_base or "wss://api.critique-labs.ai/v1/ws/search"
+
+            try:
+                # Create connection with headers in the correct format
+                async with websockets.connect(
+                    uri, additional_headers={"X-API-Key": api_key}
+                ) as websocket:
+                    # Send the search request
+                    await websocket.send(
+                        json.dumps(
+                            {
+                                "prompt": prompt,
+                            }
+                        )
+                    )
+
+                    # Receive and process streaming responses
+                    while not stop_event.is_set():
+                        try:
+                            response = await websocket.recv()
+                            data = json.loads(response)
+                            response_queue.put(data)
+
+                            # If we get an error, we're done
+                            if data["type"] == "error":
+                                break
+                        except websockets.exceptions.ConnectionClosed:
+                            # This is the expected end signal - not an error
+                            logger.info(
+                                "WebSocket connection closed by server - this is the expected end signal"
+                            )
+                            connection_closed.set()  # Signal that the connection was closed normally
+                            break
+            except Exception as e:
+                # Only log as error for unexpected exceptions
+                logger.error(f"WebSocket error: {str(e)}")
+                response_queue.put(
+                    {"type": "error", "content": f"WebSocket error: {str(e)}"}
+                )
+            finally:
+                # Always set connection_closed when we exit
+                connection_closed.set()
+
+        asyncio.run(connect_and_stream())
+
+    # Start the WebSocket thread
+    thread = threading.Thread(target=websocket_thread)
+    thread.daemon = True
+    thread.start()
+
+    try:
+        text = ""
+        context_info = []
+
+        # Process responses from the queue until connection is closed
+        while not connection_closed.is_set() or not response_queue.empty():
+            try:
+                # Wait for a response with timeout
+                data = response_queue.get(
+                    timeout=0.5
+                )  # Short timeout to check connection_closed frequently
+
+                if data["type"] == "response":
+                    text += data["content"]
+                    yield {
+                        "text": text,
+                        "error_code": 0,
+                    }
+                elif data["type"] == "context":
+                    # Collect context information
+                    context_info.append(data["content"])
+                elif data["type"] == "error":
+                    logger.error(f"Critique API error: {data['content']}")
+                    yield {
+                        "text": f"**API REQUEST ERROR** Reason: {data['content']}",
+                        "error_code": 1,
+                    }
+                    break
+
+                response_queue.task_done()
+            except queue.Empty:
+                # Just a timeout to check if connection is closed
+                continue
+
+    except Exception as e:
+        logger.error(f"Error in critique_api_stream_iter: {str(e)}")
+        yield {
+            "text": f"**API REQUEST ERROR** Reason: {str(e)}",
+            "error_code": 1,
+        }
+    finally:
+        # Signal the thread to stop and wait for it to finish
+        stop_event.set()
+        thread.join(timeout=5)
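
Review note: the new critique_api_stream_iter keeps the synchronous generator interface the other providers use by running the async websockets client in a daemon thread and bridging messages back through a thread-safe queue.Queue. A minimal consumption sketch, not part of the patch (the model name and prompt below are illustrative, and CRITIQUE_API_KEY is assumed to be set in the environment):

    messages = [{"role": "user", "content": "Summarize today's top AI news."}]  # illustrative
    for chunk in critique_api_stream_iter(
        "critique-search",  # illustrative model name
        messages,
        temperature=0.7,
        top_p=1.0,
        max_new_tokens=1024,
    ):
        if chunk["error_code"] != 0:
            raise RuntimeError(chunk["text"])
        print(chunk["text"], end="\r")  # each chunk carries the full accumulated text so far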