Skip to content

Commit c98583f

Browse files
committed
fix: address code audit issues - security, error handling, config
Critical fixes: - Fix broad exception handling in startup (raise instead of silent pass) - Add input validation to API models (field constraints with Annotated) - Improve connection pool configuration (min_size=2, max_size=20, timeout) High priority fixes: - Add rate limiting to API endpoints (30/min for search, 10/min for chat) - Extract hardcoded retry values to config (RETRY_MAX_ATTEMPTS, etc.) - Fix error handling to avoid exposing stack traces (use logger.error) Medium priority fixes: - Make CORS configurable (CORS_ORIGINS, allow_credentials, etc.) - Add request logging middleware for audit trail - Improve health check with dependency validation (database, embeddings, chat_agent) Low priority fixes: - Fix import order across project using ruff --select I --fix - Enable stricter mypy checks (check_untyped_defs, strict_optional) Configuration changes: - Add APIConfig with CORS settings - Add RetryConfig with retry parameters - Update .env.example with new configuration options Dependencies: - Add slowapi>=0.1.9 for rate limiting
1 parent fbb2c5d commit c98583f

File tree

7 files changed

+194
-92
lines changed

7 files changed

+194
-92
lines changed

.env.example

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,14 @@ SCRAPE_USER_AGENT=Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)
3030

3131
# Application Configuration
3232
LOG_LEVEL=INFO
33+
34+
# CORS Configuration
35+
CORS_ORIGINS=*
36+
CORS_ALLOW_CREDENTIALS=true
37+
CORS_ALLOW_METHODS=*
38+
CORS_ALLOW_HEADERS=*
39+
40+
# Retry Configuration
41+
RETRY_MAX_ATTEMPTS=7
42+
RETRY_START_DELAY=10.0
43+
RETRY_DELAY_INCREMENT=1.0

api/search_api.py

Lines changed: 132 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,72 @@
44

55
import asyncio
66
import json
7+
import logging
78
from datetime import datetime
8-
from typing import Any
9+
from typing import Annotated, Any
910

10-
from fastapi import FastAPI, HTTPException
11-
from fastapi.responses import StreamingResponse
11+
from fastapi import FastAPI, HTTPException, Request
1212
from fastapi.middleware.cors import CORSMiddleware
13+
from fastapi.responses import StreamingResponse
1314
from fastapi.staticfiles import StaticFiles
14-
from pydantic import BaseModel
15+
from pydantic import BaseModel, Field
16+
from slowapi import Limiter, _rate_limit_exceeded_handler
17+
from slowapi.errors import RateLimitExceeded
18+
from slowapi.util import get_remote_address
1519

16-
from lib.db.postgres_client import PostgresClient
17-
from lib.db.chat_schema import ensure_chat_schema
18-
from lib.chat_agent_v2 import KGChatAgentV2
1920
from lib.advanced_search_features import AdvancedSearchFeatures
21+
from lib.chat_agent_v2 import KGChatAgentV2
22+
from lib.db.chat_schema import ensure_chat_schema
23+
from lib.db.postgres_client import PostgresClient
2024
from lib.embeddings.google_client import GoogleEmbeddingClient
2125
from lib.utils.config import config
2226

27+
logger = logging.getLogger(__name__)
28+
29+
limiter = Limiter(key_func=get_remote_address)
2330
app = FastAPI(title="Parliamentary Search API")
31+
app.state.limiter = limiter
32+
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
2433

34+
cors_origins = config.api.cors_origins.split(",") if config.api.cors_origins != "*" else ["*"]
2535
app.add_middleware(
2636
CORSMiddleware,
27-
allow_origins=["*"],
28-
allow_credentials=True,
29-
allow_methods=["*"],
30-
allow_headers=["*"],
37+
allow_origins=cors_origins,
38+
allow_credentials=config.api.cors_allow_credentials,
39+
allow_methods=(
40+
config.api.cors_allow_methods.split(",") if config.api.cors_allow_methods != "*" else ["*"]
41+
),
42+
allow_headers=(
43+
config.api.cors_allow_headers.split(",") if config.api.cors_allow_headers != "*" else ["*"]
44+
),
3145
)
3246

3347

48+
@app.middleware("http")
49+
async def log_requests(request: Request, call_next):
50+
"""Log all incoming requests for audit trail."""
51+
start_time = datetime.now()
52+
53+
logger.info(
54+
f"Request: {request.method} {request.url.path} from {request.client.host if request.client else 'unknown'}"
55+
)
56+
57+
try:
58+
response = await call_next(request)
59+
duration = (datetime.now() - start_time).total_seconds()
60+
logger.info(
61+
f"Response: {response.status_code} for {request.method} {request.url.path} - {duration:.3f}s"
62+
)
63+
return response
64+
except Exception as e:
65+
duration = (datetime.now() - start_time).total_seconds()
66+
logger.error(
67+
f"Request failed: {request.method} {request.url.path} - {duration:.3f}s - {e}",
68+
exc_info=True,
69+
)
70+
raise
71+
72+
3473
def _get_postgres() -> PostgresClient:
3574
assert postgres is not None
3675
return postgres
@@ -57,8 +96,9 @@ def _startup() -> None:
5796
postgres = PostgresClient()
5897
try:
5998
ensure_chat_schema(postgres)
60-
except Exception:
61-
pass
99+
except Exception as e:
100+
print(f"❌ Failed to ensure chat schema: {e}")
101+
raise
62102
embedding_client = GoogleEmbeddingClient()
63103
advanced_search = AdvancedSearchFeatures(
64104
postgres=postgres,
@@ -73,9 +113,9 @@ def _startup() -> None:
73113

74114

75115
class SearchRequest(BaseModel):
76-
query: str
77-
limit: int = 20
78-
alpha: float = 0.6
116+
query: Annotated[str, Field(min_length=1, max_length=500)]
117+
limit: Annotated[int, Field(ge=1, le=100)] = 20
118+
alpha: Annotated[float, Field(ge=0.0, le=1.0)] = 0.6
79119

80120

81121
class SearchResult(BaseModel):
@@ -95,9 +135,9 @@ class SearchResult(BaseModel):
95135

96136

97137
class TemporalSearchRequest(BaseModel):
98-
query: str
99-
limit: int = 20
100-
alpha: float | None = 0.6
138+
query: Annotated[str, Field(min_length=1, max_length=500)]
139+
limit: Annotated[int, Field(ge=1, le=100)] = 20
140+
alpha: Annotated[float | None, Field(ge=0.0, le=1.0)] = 0.6
101141
start_date: str | None = None
102142
end_date: str | None = None
103143
speaker_id: str | None = None
@@ -221,7 +261,7 @@ class ChatUsedEdge(BaseModel):
221261

222262

223263
class ChatMessageRequest(BaseModel):
224-
content: str
264+
content: Annotated[str, Field(min_length=1, max_length=5000)]
225265

226266

227267
class ChatMessageResponse(BaseModel):
@@ -414,14 +454,18 @@ def retrieve_sentences_for_paragraphs(paragraph_ids: list[str]) -> list[dict[str
414454

415455

416456
@app.post("/search", response_model=list[SearchResult])
417-
async def search(request: SearchRequest):
457+
@limiter.limit("30/minute")
458+
async def search(request: Request, search_request: SearchRequest):
418459
"""Hybrid search combining entity + paragraph vector search."""
419460
try:
420461
try:
421-
query_embedding = _get_embedding_client().generate_query_embedding(request.query)
462+
query_embedding = _get_embedding_client().generate_query_embedding(search_request.query)
422463
except Exception as e:
423464
print(f"⚠️ Embeddings unavailable; falling back to BM25 only: {e}")
424-
return [SearchResult(**r) for r in bm25_search_sentences(request.query, request.limit)]
465+
return [
466+
SearchResult(**r)
467+
for r in bm25_search_sentences(search_request.query, search_request.limit)
468+
]
425469

426470
phase1_entities = vector_search_entities(query_embedding, 10)
427471
phase1_paragraphs = vector_search_paragraphs(query_embedding, 10)
@@ -466,43 +510,45 @@ async def search(request: SearchRequest):
466510
score=float(r["score"]),
467511
search_type="hybrid",
468512
)
469-
for r in scored[: request.limit]
513+
for r in scored[: search_request.limit]
470514
]
471515

472516
except Exception as e:
473-
raise HTTPException(status_code=500, detail=str(e))
517+
logger.error(f"Search failed: {e}", exc_info=True)
518+
raise HTTPException(status_code=500, detail="Internal server error")
474519

475520

476521
@app.post("/search/temporal", response_model=list[SearchResult])
477-
async def temporal_search(request: TemporalSearchRequest):
522+
@limiter.limit("30/minute")
523+
async def temporal_search(request: Request, temporal_request: TemporalSearchRequest):
478524
"""Temporal search with filters."""
479525
try:
480526
try:
481527
results = _get_advanced_search().temporal_search(
482-
request.query,
483-
request.start_date,
484-
request.end_date,
485-
request.speaker_id,
486-
request.entity_type,
487-
request.limit,
528+
temporal_request.query,
529+
temporal_request.start_date,
530+
temporal_request.end_date,
531+
temporal_request.speaker_id,
532+
temporal_request.entity_type,
533+
temporal_request.limit,
488534
)
489535
return [SearchResult(**r) for r in results]
490536
except Exception as e:
491537
# If embeddings are not available, fall back to BM25 filtered by dates/speaker.
492538
print(f"⚠️ Temporal embeddings unavailable; using BM25 fallback: {e}")
493539

494540
where = ["s.tsv @@ plainto_tsquery('english', %s)"]
495-
params: list[Any] = [request.query]
541+
params: list[Any] = [temporal_request.query]
496542

497-
if request.start_date:
543+
if temporal_request.start_date:
498544
where.append("s.video_date >= to_date(%s, 'YYYY-MM-DD')")
499-
params.append(request.start_date)
500-
if request.end_date:
545+
params.append(temporal_request.start_date)
546+
if temporal_request.end_date:
501547
where.append("s.video_date <= to_date(%s, 'YYYY-MM-DD')")
502-
params.append(request.end_date)
503-
if request.speaker_id:
548+
params.append(temporal_request.end_date)
549+
if temporal_request.speaker_id:
504550
where.append("s.speaker_id = %s")
505-
params.append(request.speaker_id)
551+
params.append(temporal_request.speaker_id)
506552

507553
sql = f"""
508554
SELECT
@@ -524,8 +570,8 @@ async def temporal_search(request: TemporalSearchRequest):
524570
LIMIT %s
525571
"""
526572
# rank query param must be first; reuse query as last before limit
527-
rank_query = request.query
528-
final_params = [rank_query, *params, request.limit]
573+
rank_query = temporal_request.query
574+
final_params = [rank_query, *params, temporal_request.limit]
529575
rows = _get_postgres().execute_query(sql, tuple(final_params))
530576

531577
return [
@@ -546,7 +592,8 @@ async def temporal_search(request: TemporalSearchRequest):
546592
for row in rows
547593
]
548594
except Exception as e:
549-
raise HTTPException(status_code=500, detail=str(e))
595+
logger.error(f"Temporal search failed: {e}", exc_info=True)
596+
raise HTTPException(status_code=500, detail="Internal server error")
550597

551598

552599
@app.get("/search/trends")
@@ -568,7 +615,8 @@ async def get_trends(
568615
moving_average=result["moving_average"],
569616
)
570617
except Exception as e:
571-
raise HTTPException(status_code=500, detail=str(e))
618+
logger.error(f"Trends retrieval failed: {e}", exc_info=True)
619+
raise HTTPException(status_code=500, detail="Internal server error")
572620

573621

574622
@app.get("/speakers")
@@ -607,7 +655,8 @@ async def get_speakers() -> list[Speaker]:
607655
for row in results
608656
]
609657
except Exception as e:
610-
raise HTTPException(status_code=500, detail=str(e))
658+
logger.error(f"Get speakers failed: {e}", exc_info=True)
659+
raise HTTPException(status_code=500, detail="Internal server error")
611660

612661

613662
@app.get("/speakers/{speaker_id}")
@@ -678,7 +727,8 @@ async def get_speaker_stats(speaker_id: str) -> SpeakerStatsResponse:
678727
except HTTPException:
679728
raise
680729
except Exception as e:
681-
raise HTTPException(status_code=500, detail=str(e))
730+
logger.error(f"Get speaker stats failed: {e}", exc_info=True)
731+
raise HTTPException(status_code=500, detail="Internal server error")
682732

683733

684734
@app.get(
@@ -729,47 +779,17 @@ async def create_thread(title: str | None = None):
729779
created_at=str(datetime.now()),
730780
)
731781
except Exception as e:
732-
raise HTTPException(status_code=500, detail=str(e))
733-
734-
735-
@app.get("/chat/threads/{thread_id}", response_model=GetThreadResponse)
736-
async def get_thread(thread_id: str):
737-
"""Get thread metadata and messages."""
738-
try:
739-
agent = _get_chat_agent()
740-
thread = agent.get_thread(thread_id)
741-
if thread is None:
742-
raise HTTPException(status_code=404, detail="Thread not found")
743-
744-
return GetThreadResponse(
745-
id=thread["id"],
746-
title=thread["title"],
747-
created_at=thread["created_at"],
748-
updated_at=thread["updated_at"],
749-
state=thread["state"],
750-
messages=[
751-
ThreadMessage(
752-
id=m["id"],
753-
role=m["role"],
754-
content=m["content"],
755-
metadata=m.get("metadata"),
756-
created_at=m["created_at"],
757-
)
758-
for m in thread["messages"]
759-
],
760-
)
761-
except HTTPException:
762-
raise
763-
except Exception as e:
764-
raise HTTPException(status_code=500, detail=str(e))
782+
logger.error(f"Create thread failed: {e}", exc_info=True)
783+
raise HTTPException(status_code=500, detail="Internal server error")
765784

766785

767786
@app.post("/chat/threads/{thread_id}/messages", response_model=ChatMessageResponse)
768-
async def send_message(thread_id: str, request: ChatMessageRequest):
787+
@limiter.limit("10/minute")
788+
async def send_message(request: Request, thread_id: str, chat_request: ChatMessageRequest):
769789
"""Send a message to a thread and get assistant response."""
770790
try:
771791
agent = _get_chat_agent()
772-
response = await agent.process_message(thread_id, request.content)
792+
response = await agent.process_message(thread_id, chat_request.content)
773793

774794
return ChatMessageResponse(
775795
thread_id=thread_id,
@@ -782,7 +802,8 @@ async def send_message(thread_id: str, request: ChatMessageRequest):
782802
except ValueError as e:
783803
raise HTTPException(status_code=404, detail=str(e))
784804
except Exception as e:
785-
raise HTTPException(status_code=500, detail=str(e))
805+
logger.error(f"Send message failed: {e}", exc_info=True)
806+
raise HTTPException(status_code=500, detail="Internal server error")
786807

787808

788809
async def stream_chat_response(thread_id: str, content: str):
@@ -851,7 +872,8 @@ async def run_agent():
851872

852873

853874
@app.get("/chat/threads/{thread_id}/messages/stream")
854-
async def stream_message(thread_id: str, content: str):
875+
@limiter.limit("10/minute")
876+
async def stream_message(request: Request, thread_id: str, content: str):
855877
"""Stream a message response with progress updates via SSE."""
856878
return StreamingResponse(
857879
stream_chat_response(thread_id, content),
@@ -866,8 +888,34 @@ async def stream_message(thread_id: str, content: str):
866888

867889
@app.get("/health")
868890
async def health():
869-
"""Health check endpoint for deployment monitoring."""
870-
return {"status": "ok", "timestamp": datetime.now().isoformat()}
891+
"""Health check endpoint for deployment monitoring with dependency validation."""
892+
health_status = {"status": "ok", "timestamp": datetime.now().isoformat(), "checks": {}}
893+
894+
try:
895+
_get_postgres().execute_query("SELECT 1")
896+
health_status["checks"]["database"] = "ok"
897+
except Exception as e:
898+
health_status["status"] = "degraded"
899+
health_status["checks"]["database"] = f"error: {e}"
900+
901+
try:
902+
if embedding_client:
903+
_get_embedding_client().generate_query_embedding("test")
904+
health_status["checks"]["embeddings"] = "ok"
905+
else:
906+
health_status["checks"]["embeddings"] = "skipped"
907+
except Exception as e:
908+
health_status["status"] = "degraded"
909+
health_status["checks"]["embeddings"] = f"error: {e}"
910+
911+
try:
912+
_get_chat_agent()
913+
health_status["checks"]["chat_agent"] = "ok"
914+
except Exception as e:
915+
health_status["status"] = "degraded"
916+
health_status["checks"]["chat_agent"] = f"error: {e}"
917+
918+
return health_status
871919

872920

873921
@app.get("/api")

0 commit comments

Comments
 (0)