2727└──────────────────────────────────────────────────────────────────────────────┘
2828"""
2929
30+ import uuid
3031from fastapi import (
3132 APIRouter ,
3233 Depends ,
3334 HTTPException ,
3435 status ,
3536 WebSocket ,
3637 WebSocketDisconnect ,
38+ Header ,
3739)
3840from sqlalchemy .orm import Session
3941from src .config .database import get_db
5759from datetime import datetime
5860import logging
5961import json
62+ from typing import Optional , Dict
6063
6164logger = logging .getLogger (__name__ )
6265
6770)
6871
6972
73+ async def get_agent_by_api_key (
74+ agent_id : str ,
75+ api_key : Optional [str ] = Header (None , alias = "x-api-key" ),
76+ authorization : Optional [str ] = Header (None ),
77+ db : Session = Depends (get_db ),
78+ ):
79+ """Flexible authentication for chat routes, allowing JWT or API key"""
80+ if authorization :
81+ # Try to authenticate with JWT token first
82+ try :
83+ # Extract token from Authorization header if needed
84+ token = (
85+ authorization .replace ("Bearer " , "" )
86+ if authorization .startswith ("Bearer " )
87+ else authorization
88+ )
89+ payload = await get_jwt_token (token )
90+ agent = agent_service .get_agent (db , agent_id )
91+ if not agent :
92+ raise HTTPException (
93+ status_code = status .HTTP_404_NOT_FOUND ,
94+ detail = "Agent not found" ,
95+ )
96+
97+ # Verify if the user has access to the agent's client
98+ await verify_user_client (payload , db , agent .client_id )
99+ return agent
100+ except Exception as e :
101+ logger .warning (f"JWT authentication failed: { str (e )} " )
102+ # If JWT fails, continue to try with API key
103+
104+ # Try to authenticate with API key
105+ if not api_key :
106+ raise HTTPException (
107+ status_code = status .HTTP_401_UNAUTHORIZED ,
108+ detail = "Authentication required (JWT or API key)" ,
109+ )
110+
111+ agent = agent_service .get_agent (db , agent_id )
112+ if not agent or not agent .config :
113+ raise HTTPException (
114+ status_code = status .HTTP_404_NOT_FOUND , detail = "Agent not found"
115+ )
116+
117+ # Verify if the API key matches
118+ if not agent .config .get ("api_key" ) or agent .config .get ("api_key" ) != api_key :
119+ raise HTTPException (
120+ status_code = status .HTTP_401_UNAUTHORIZED , detail = "Invalid API key"
121+ )
122+
123+ return agent
124+
125+
70126@router .websocket ("/ws/{agent_id}/{external_id}" )
71127async def websocket_chat (
72128 websocket : WebSocket ,
@@ -82,32 +138,49 @@ async def websocket_chat(
82138 # Wait for authentication message
83139 try :
84140 auth_data = await websocket .receive_json ()
85- logger .info (f"Received authentication data: { auth_data } " )
141+ logger .info (f"Authentication data received : { auth_data } " )
86142
87143 if not (
88- auth_data .get ("type" ) == "authorization" and auth_data .get ("token" )
144+ auth_data .get ("type" ) == "authorization"
145+ and (auth_data .get ("token" ) or auth_data .get ("api_key" ))
89146 ):
90147 logger .warning ("Invalid authentication message" )
91148 await websocket .close (code = status .WS_1008_POLICY_VIOLATION )
92149 return
93150
94- token = auth_data ["token" ]
95- # Verify the token
96- payload = await get_jwt_token_ws (token )
97- if not payload :
98- logger .warning ("Invalid token" )
99- await websocket .close (code = status .WS_1008_POLICY_VIOLATION )
100- return
101-
102- # Verify if the agent belongs to the user's client
151+ # Verify if the agent exists
103152 agent = agent_service .get_agent (db , agent_id )
104153 if not agent :
105154 logger .warning (f"Agent { agent_id } not found" )
106155 await websocket .close (code = status .WS_1008_POLICY_VIOLATION )
107156 return
108157
109- # Verify if the user has access to the agent (via client)
110- await verify_user_client (payload , db , agent .client_id )
158+ # Verify authentication
159+ is_authenticated = False
160+
161+ # Try with JWT token
162+ if auth_data .get ("token" ):
163+ try :
164+ payload = await get_jwt_token_ws (auth_data ["token" ])
165+ if payload :
166+ # Verify if the user has access to the agent
167+ await verify_user_client (payload , db , agent .client_id )
168+ is_authenticated = True
169+ except Exception as e :
170+ logger .warning (f"JWT authentication failed: { str (e )} " )
171+
172+ # If JWT fails, try with API key
173+ if not is_authenticated and auth_data .get ("api_key" ):
174+ if agent .config and agent .config .get ("api_key" ) == auth_data .get (
175+ "api_key"
176+ ):
177+ is_authenticated = True
178+ else :
179+ logger .warning ("Invalid API key" )
180+
181+ if not is_authenticated :
182+ await websocket .close (code = status .WS_1008_POLICY_VIOLATION )
183+ return
111184
112185 logger .info (
113186 f"WebSocket connection established for agent { agent_id } and external_id { external_id } "
@@ -174,19 +247,9 @@ async def websocket_chat(
174247)
175248async def chat (
176249 request : ChatRequest ,
250+ _ = Depends (get_agent_by_api_key ),
177251 db : Session = Depends (get_db ),
178- payload : dict = Depends (get_jwt_token ),
179252):
180- # Verify if the agent belongs to the user's client
181- agent = agent_service .get_agent (db , request .agent_id )
182- if not agent :
183- raise HTTPException (
184- status_code = status .HTTP_404_NOT_FOUND , detail = "Agent not found"
185- )
186-
187- # Verify if the user has access to the agent (via client)
188- await verify_user_client (payload , db , agent .client_id )
189-
190253 try :
191254 final_response = await run_agent (
192255 request .agent_id ,
0 commit comments