Commit 8cc6695

✨ feat: Optimize PostgreSQL Queries and Add Optional Detailed Query Logging (#172)
* fix: vector index creation in database service on `cmetadata->>'file_id'` due to upstream query change
* ✨ feat: Optimize PostgreSQL Queries and Add Optional Detailed Query Logging
  - Introduced a new environment variable `DEBUG_PGVECTOR_QUERIES` to enable detailed logging of pgvector operations.
  - Implemented query logging setup in the ExtendedPgVector class, capturing execution time and parameters for relevant queries.
  - Updated README.md to document the new environment variable.
1 parent aa5d89f · commit 8cc6695

File tree

5 files changed: +122 −24 lines changed

README.md

Lines changed: 1 addition & 0 deletions

@@ -61,6 +61,7 @@ The following environment variables are required to run the application:
 - `RAG_UPLOAD_DIR`: (Optional) The directory where uploaded files are stored. Default value is "./uploads/".
 - `PDF_EXTRACT_IMAGES`: (Optional) A boolean value indicating whether to extract images from PDF files. Default value is "False".
 - `DEBUG_RAG_API`: (Optional) Set to "True" to show more verbose logging output in the server console, and to enable postgresql database routes
+- `DEBUG_PGVECTOR_QUERIES`: (Optional) Set to "True" to enable detailed PostgreSQL query logging for pgvector operations. Useful for debugging performance issues with vector database queries.
 - `CONSOLE_JSON`: (Optional) Set to "True" to log as json for Cloud Logging aggregations
 - `EMBEDDINGS_PROVIDER`: (Optional) either "openai", "bedrock", "azure", "huggingface", "huggingfacetei", "vertexai", or "ollama", where "huggingface" uses sentence_transformers; defaults to "openai"
 - `EMBEDDINGS_MODEL`: (Optional) Set a valid embeddings model to use from the configured provider.
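
For reference, the new flag is parsed permissively rather than with a strict boolean cast. A minimal sketch of the same check (the accepted truthy spellings below come straight from the `setup_query_logging` code added in this commit):

import os

# Any of "true", "1", "yes", "on" (case-insensitive) enables query logging;
# this mirrors the check added in extended_pg_vector.py in this commit.
os.environ["DEBUG_PGVECTOR_QUERIES"] = "True"
enabled = os.getenv("DEBUG_PGVECTOR_QUERIES", "").lower() in ["true", "1", "yes", "on"]
print(enabled)  # True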

app/services/database.py

Lines changed: 18 additions & 6 deletions

@@ -2,6 +2,7 @@
 import asyncpg
 from app.config import DSN, logger
 
+
 class PSQLDatabase:
     pool = None
 
@@ -17,18 +18,29 @@ async def close_pool(cls):
         await cls.pool.close()
         cls.pool = None
 
-async def ensure_custom_id_index_on_embedding():
+
+async def ensure_vector_indexes():
     table_name = "langchain_pg_embedding"
     column_name = "custom_id"
     # You might want to standardize the index naming convention
     index_name = f"idx_{table_name}_{column_name}"
 
     pool = await PSQLDatabase.get_pool()
     async with pool.acquire() as conn:
-        await conn.execute(f"""
-            CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} ({column_name});
-        """)
-        logger.debug(f"Checking if index '{index_name}' on '{table_name}({column_name}) exists, if not found then the index is created.'")
+        await conn.execute(
+            f"""
+            CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} ({column_name});
+            """
+        )
+
+        await conn.execute(
+            f"""
+            CREATE INDEX IF NOT EXISTS idx_{table_name}_file_id
+            ON {table_name} ((cmetadata->>'file_id'));
+            """
+        )
+
+        logger.info("Vector database indexes ensured")
 
 
 async def pg_health_check() -> bool:
@@ -39,4 +51,4 @@ async def pg_health_check() -> bool:
         return True
     except Exception as e:
         logger.error(f"Health check failed: {e}")
-        return False
+        return False
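
The second `CREATE INDEX` is an expression index on `(cmetadata->>'file_id')`, so it only accelerates queries that filter on exactly that JSONB expression, which matches the upstream query change named in the commit title. A quick way to confirm both indexes exist after startup is to query PostgreSQL's standard `pg_indexes` view; the helper below is a hypothetical verification script, not part of the commit, and the DSN is a placeholder:

import asyncio

import asyncpg

async def list_embedding_indexes(dsn: str) -> list[str]:
    # List the indexes defined on the langchain_pg_embedding table.
    conn = await asyncpg.connect(dsn)
    try:
        rows = await conn.fetch(
            "SELECT indexname FROM pg_indexes WHERE tablename = $1",
            "langchain_pg_embedding",
        )
        return [row["indexname"] for row in rows]
    finally:
        await conn.close()

# Expect idx_langchain_pg_embedding_custom_id and idx_langchain_pg_embedding_file_id
# once ensure_vector_indexes() has run. The DSN below is a placeholder.
print(asyncio.run(list_embedding_indexes("postgresql://user:pass@localhost:5432/mydb")))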

app/services/vector_store/extended_pg_vector.py

Lines changed: 61 additions & 5 deletions

@@ -1,19 +1,73 @@
+import os
+import time
+import logging
 from typing import Optional
-
+from sqlalchemy import event
 from sqlalchemy import delete
 from sqlalchemy.orm import Session
+from sqlalchemy.engine import Engine
 from langchain_core.documents import Document
 from langchain_community.vectorstores.pgvector import PGVector
 
+
 class ExtendedPgVector(PGVector):
+    _query_logging_setup = False
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.setup_query_logging()
+
+    def setup_query_logging(self):
+        """Enable query logging for this vector store only if DEBUG_PGVECTOR_QUERIES is set"""
+        # Only setup logging if the environment variable is set to a truthy value
+        debug_queries = os.getenv("DEBUG_PGVECTOR_QUERIES", "").lower()
+        if debug_queries not in ["true", "1", "yes", "on"]:
+            return
+
+        # Only setup once per class
+        if ExtendedPgVector._query_logging_setup:
+            return
+
+        logger = logging.getLogger("pgvector.queries")
+        logger.setLevel(logging.INFO)
+
+        # Create handler if it doesn't exist
+        if not logger.handlers:
+            handler = logging.StreamHandler()
+            formatter = logging.Formatter("%(asctime)s - PGVECTOR QUERY - %(message)s")
+            handler.setFormatter(formatter)
+            logger.addHandler(handler)
+
+        @event.listens_for(Engine, "before_cursor_execute")
+        def receive_before_cursor_execute(
+            conn, cursor, statement, parameters, context, executemany
+        ):
+            if "langchain_pg_embedding" in statement:
+                context._query_start_time = time.time()
+                logger.info(f"STARTING QUERY: {statement}")
+                logger.info(f"PARAMETERS: {parameters}")
+
+        @event.listens_for(Engine, "after_cursor_execute")
+        def receive_after_cursor_execute(
+            conn, cursor, statement, parameters, context, executemany
+        ):
+            if "langchain_pg_embedding" in statement:
+                total = time.time() - context._query_start_time
+                logger.info(f"COMPLETED QUERY in {total:.4f}s")
+                logger.info("-" * 50)
+
+        ExtendedPgVector._query_logging_setup = True
+
     def get_all_ids(self) -> list[str]:
         with Session(self._bind) as session:
             results = session.query(self.EmbeddingStore.custom_id).all()
             return [result[0] for result in results if result[0] is not None]
-
+
     def get_filtered_ids(self, ids: list[str]) -> list[str]:
         with Session(self._bind) as session:
-            query = session.query(self.EmbeddingStore.custom_id).filter(self.EmbeddingStore.custom_id.in_(ids))
+            query = session.query(self.EmbeddingStore.custom_id).filter(
+                self.EmbeddingStore.custom_id.in_(ids)
+            )
             results = query.all()
             return [result[0] for result in results if result[0] is not None]
 
@@ -45,7 +99,9 @@ def _delete_multiple(
             if not collection:
                 self.logger.warning("Collection not found")
                 return
-            stmt = stmt.where(self.EmbeddingStore.collection_id == collection.uuid)
+            stmt = stmt.where(
+                self.EmbeddingStore.collection_id == collection.uuid
+            )
             stmt = stmt.where(self.EmbeddingStore.custom_id.in_(ids))
             session.execute(stmt)
-            session.commit()
+            session.commit()
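
Note that `event.listens_for(Engine, ...)` registers the listeners on the Engine class, so they fire for every SQLAlchemy engine in the process; the `_query_logging_setup` class flag is what keeps multiple ExtendedPgVector instances from registering duplicate listeners. A minimal, self-contained sketch of the same before/after timing pattern, run against an in-memory SQLite engine instead of pgvector:

import time

from sqlalchemy import create_engine, event, text
from sqlalchemy.engine import Engine

@event.listens_for(Engine, "before_cursor_execute")
def before_execute(conn, cursor, statement, parameters, context, executemany):
    # Stash the start time on the execution context, as the commit's code does.
    context._query_start_time = time.time()

@event.listens_for(Engine, "after_cursor_execute")
def after_execute(conn, cursor, statement, parameters, context, executemany):
    elapsed = time.time() - context._query_start_time
    print(f"{elapsed:.4f}s  {statement}")

engine = create_engine("sqlite:///:memory:")
with engine.connect() as conn:
    conn.execute(text("SELECT 1"))  # prints a timing line for this statement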

main.py

Lines changed: 28 additions & 9 deletions

@@ -9,31 +9,49 @@
 
 from starlette.responses import JSONResponse
 
-from app.config import VectorDBType, debug_mode, RAG_HOST, RAG_PORT, CHUNK_SIZE, CHUNK_OVERLAP, PDF_EXTRACT_IMAGES, VECTOR_DB_TYPE, \
-    LogMiddleware, logger
+from app.config import (
+    VectorDBType,
+    debug_mode,
+    RAG_HOST,
+    RAG_PORT,
+    CHUNK_SIZE,
+    CHUNK_OVERLAP,
+    PDF_EXTRACT_IMAGES,
+    VECTOR_DB_TYPE,
+    LogMiddleware,
+    logger,
+)
 from app.middleware import security_middleware
 from app.routes import document_routes, pgvector_routes
-from app.services.database import PSQLDatabase, ensure_custom_id_index_on_embedding
+from app.services.database import PSQLDatabase, ensure_vector_indexes
+
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # Startup logic goes here
     # Create bounded thread pool executor based on CPU cores
-    max_workers = min(int(os.getenv("RAG_THREAD_POOL_SIZE", str(os.cpu_count()))), 8)  # Cap at 8
-    app.state.thread_pool = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="rag-worker")
-    logger.info(f"Initialized thread pool with {max_workers} workers (CPU cores: {os.cpu_count()})")
-
+    max_workers = min(
+        int(os.getenv("RAG_THREAD_POOL_SIZE", str(os.cpu_count()))), 8
+    )  # Cap at 8
+    app.state.thread_pool = ThreadPoolExecutor(
+        max_workers=max_workers, thread_name_prefix="rag-worker"
+    )
+    logger.info(
+        f"Initialized thread pool with {max_workers} workers (CPU cores: {os.cpu_count()})"
+    )
+
     if VECTOR_DB_TYPE == VectorDBType.PGVECTOR:
         await PSQLDatabase.get_pool()  # Initialize the pool
-        await ensure_custom_id_index_on_embedding()
+        await ensure_vector_indexes()
 
     yield
-
+
     # Cleanup logic
     logger.info("Shutting down thread pool")
     app.state.thread_pool.shutdown(wait=True)
     logger.info("Thread pool shutdown complete")
 
+
 app = FastAPI(lifespan=lifespan, debug=debug_mode)
 
 app.add_middleware(
@@ -74,5 +92,6 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
     },
 )
 
+
 if __name__ == "__main__":
     uvicorn.run(app, host=RAG_HOST, port=RAG_PORT, log_config=None)
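
The startup/shutdown split follows FastAPI's lifespan contract: everything before `yield` runs once at startup, everything after it runs once at shutdown. A stripped-down sketch of the pattern, with the database and config details from main.py omitted:

from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager

from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: create shared resources and stash them on app.state.
    app.state.thread_pool = ThreadPoolExecutor(max_workers=4)
    yield
    # Shutdown: release them once the server stops serving requests.
    app.state.thread_pool.shutdown(wait=True)

app = FastAPI(lifespan=lifespan)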

tests/services/test_database.py

Lines changed: 14 additions & 4 deletions

@@ -1,39 +1,49 @@
 import pytest
-from app.services.database import ensure_custom_id_index_on_embedding, PSQLDatabase
+from app.services.database import ensure_vector_indexes, PSQLDatabase
+
 
 # Create dummy classes to simulate a database connection and pool
 class DummyConnection:
     async def fetchval(self, query, index_name):
         # Simulate that the index does not exist
         return False
+
     async def execute(self, query):
         return "Executed"
 
+
 class DummyAcquire:
     async def __aenter__(self):
         return DummyConnection()
+
     async def __aexit__(self, exc_type, exc, tb):
         pass
 
+
 class DummyPool:
     def acquire(self):
         return DummyAcquire()
 
+
 class DummyDatabase:
     pool = DummyPool()
 
     @classmethod
     async def get_pool(cls):
         return cls.pool
 
+
 @pytest.fixture
 def dummy_pool(monkeypatch):
     monkeypatch.setattr(PSQLDatabase, "get_pool", DummyDatabase.get_pool)
     return DummyPool()
 
+
 import asyncio
+
+
 @pytest.mark.asyncio
-async def test_ensure_custom_id_index_on_embedding(monkeypatch, dummy_pool):
-    result = await ensure_custom_id_index_on_embedding()
+async def test_ensure_vector_indexes(monkeypatch, dummy_pool):
+    result = await ensure_vector_indexes()
     # If no exceptions are raised, the function worked as expected.
-    assert result is None
+    assert result is None
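
The dummy classes work because `async with pool.acquire() as conn:` only requires the async context manager protocol, so no real asyncpg pool is needed. A runnable sketch of that duck typing, separate from the test file and using hypothetical Fake* names:

import asyncio

class FakeConnection:
    async def execute(self, query):
        return "Executed"

class FakeAcquire:
    # Implementing __aenter__/__aexit__ is all `async with` requires.
    async def __aenter__(self):
        return FakeConnection()

    async def __aexit__(self, exc_type, exc, tb):
        pass

async def demo():
    async with FakeAcquire() as conn:
        print(await conn.execute("CREATE INDEX IF NOT EXISTS ..."))  # Executed

asyncio.run(demo())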
