|
| 1 | +import os |
| 2 | +from typing import List, Dict, Tuple, Optional |
| 3 | +from sqlalchemy import inspect, text |
| 4 | +from sqlmodel import SQLModel, create_engine, Session |
| 5 | +from langchain_openai import OpenAIEmbeddings |
| 6 | +from langchain.vectorstores import PGVector |
| 7 | +from langchain.schema import Document |
| 8 | +import logging |
| 9 | +from apps.chat.schemas.chat_base_schema import LLMConfig |
| 10 | +from common.core.config import settings |
| 11 | +from common.core.db import engine |
| 12 | + |
class SchemaEmbeddingManager:
    """Embed database schema metadata and retrieve the parts relevant to a query.

    Extracts table/column information from the shared database, turns it into
    LangChain documents, stores their embeddings in PgVector, and answers
    similarity searches so an LLM prompt can include only the relevant schema.
    """

    def __init__(self, config: LLMConfig):
        """Initialize SchemaEmbeddingManager.

        Args:
            config: LLM configuration supplying the embedding model name,
                API base URL and API key.
        """
        self.embedding_model = OpenAIEmbeddings(
            model=config.model_name,
            openai_api_base=config.api_base_url,
            api_key=config.api_key,
        )
        self.engine = engine
        # PGVector needs a plain connection string. str(engine.url) masks the
        # password in SQLAlchemy 1.4+, so render it explicitly un-masked.
        # (Fixes AttributeError: the original referenced self.db_uri in
        # store_embeddings/get_relevant_schema without ever assigning it.)
        self.db_uri = self.engine.url.render_as_string(hide_password=False)
        self._setup_vector_extension()

    def _setup_vector_extension(self):
        """Ensure the PgVector extension (and a fallback `vector` type) exist.

        Raises:
            Exception: re-raised after logging if the DDL fails.
        """
        try:
            with Session(self.engine) as session:
                # Enable the pgvector extension (no-op if already installed).
                session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
                session.commit()

                # Create a fallback composite `vector` type if the extension
                # did not provide one (e.g. restricted environments).
                session.execute(text("""
                    DO $$
                    BEGIN
                        IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'vector') THEN
                            CREATE TYPE vector AS (x float8[]);
                        END IF;
                    END
                    $$;
                """))
                session.commit()
                logging.info("Vector extension setup completed successfully")
        except Exception as e:
            logging.error(f"Failed to setup vector extension: {str(e)}")
            raise

    def get_session(self) -> Session:
        """Get a new database session bound to the shared engine."""
        return Session(self.engine)

    def extract_schema_info(self, schema_name: str = "public") -> List[Dict]:
        """
        Extract database schema information and structure it.

        Args:
            schema_name: Schema name to extract, defaults to public

        Returns:
            List of dictionaries containing schema information, one per table;
            each carries a "columns" list of per-column dictionaries.
        """
        schema_info = []
        inspector = inspect(self.engine)

        for table_name in inspector.get_table_names(schema=schema_name):
            table_comment = inspector.get_table_comment(table_name, schema=schema_name)

            # Table-level information
            table_info = {
                "type": "table",
                "schema": schema_name,
                "name": table_name,
                "description": table_comment.get('text') if table_comment else f"Table {table_name} in schema {schema_name}",
                "columns": []
            }

            # Hoist per-table constraint lookups out of the column loop.
            # get_pk_constraint()['constrained_columns'] is a list of column
            # NAME strings (the original indexed pk['name'] into those strings,
            # which raised TypeError on any table with a primary key).
            pk_columns = set(
                inspector.get_pk_constraint(table_name, schema=schema_name)
                .get('constrained_columns') or []
            )
            foreign_keys = inspector.get_foreign_keys(table_name, schema=schema_name)

            # Column-level information
            for column in inspector.get_columns(table_name, schema=schema_name):
                col_name = column['name']
                # Get column comment directly from column info
                col_comment = column.get('comment', f"Column {col_name} in table {table_name}")

                column_info = {
                    "type": "column",
                    "schema": schema_name,
                    "table": table_name,
                    "name": col_name,
                    "data_type": str(column['type']),
                    "description": col_comment,
                    "is_primary_key": col_name in pk_columns,
                    "foreign_key": None  # Will be updated below
                }

                # Record the first foreign key that constrains this column.
                for fk in foreign_keys:
                    if col_name in fk['constrained_columns']:
                        column_info["foreign_key"] = f"{fk['referred_table']}.{fk['referred_columns'][0]}"
                        break

                table_info["columns"].append(column_info)

            schema_info.append(table_info)

        return schema_info

    def generate_schema_documents(self, schema_info: List[Dict]) -> List[Document]:
        """
        Convert schema information into LangChain Document objects.

        Args:
            schema_info: Schema information from extract_schema_info

        Returns:
            List of LangChain Document objects: one per table followed by
            one per column, with retrieval metadata attached.
        """
        documents = []

        for table in schema_info:
            # Table-level document
            table_content = f"Table {table['schema']}.{table['name']}: {table['description']}. "
            table_content += f"Contains columns: {', '.join(col['name'] for col in table['columns'])}"

            table_doc = Document(
                page_content=table_content,
                metadata={
                    "type": "table",
                    "schema": table["schema"],
                    "name": table["name"],
                    "description": table["description"],
                    "full_name": f"{table['schema']}.{table['name']}"
                }
            )
            documents.append(table_doc)

            # Column-level documents
            for column in table["columns"]:
                column_content = f"Column {table['schema']}.{table['name']}.{column['name']}: {column['description']}. "
                column_content += f"Data type: {column['data_type']}. "
                if column["is_primary_key"]:
                    column_content += "This is a primary key. "
                if column["foreign_key"]:
                    column_content += f"Foreign key to {column['foreign_key']}."

                column_doc = Document(
                    page_content=column_content,
                    metadata={
                        "type": "column",
                        "schema": table["schema"],
                        "table": table["name"],
                        "name": column["name"],
                        "description": column["description"],
                        "data_type": column["data_type"],
                        "full_name": f"{table['schema']}.{table['name']}.{column['name']}"
                    }
                )
                documents.append(column_doc)

        return documents

    def store_embeddings(self, documents: List[Document], collection_name: str = "schema_embeddings"):
        """Store schema document embeddings in PgVector.

        Args:
            documents: Documents from generate_schema_documents
            collection_name: Vector collection name

        Raises:
            Exception: re-raised after logging if the vector store write fails.
        """
        try:
            PGVector.from_documents(
                embedding=self.embedding_model,
                documents=documents,
                collection_name=collection_name,
                connection_string=self.db_uri,
                pre_delete_collection=True  # Optionally recreate collection
            )
            logging.info(f"Successfully stored {len(documents)} embeddings in collection {collection_name}")
        except Exception as e:
            logging.error(f"Failed to store embeddings: {str(e)}")
            raise

    def get_relevant_schema(
        self,
        query: str,
        collection_name: str = "schema_embeddings",
        top_k: int = 5
    ) -> Tuple[List[Dict], List[Dict]]:
        """Get relevant schema information based on user query.

        Args:
            query: User query text
            collection_name: Vector collection name
            top_k: Number of most relevant items to return

        Returns:
            Tuple containing (relevant_tables_list, relevant_columns_list)

        Raises:
            Exception: re-raised after logging if the similarity search fails.
        """
        try:
            store = PGVector(
                collection_name=collection_name,
                connection_string=self.db_uri,
                embedding_function=self.embedding_model,
            )

            # Execute similarity search
            docs = store.similarity_search_with_score(query, k=top_k)

            relevant_tables = []
            relevant_columns = []
            seen_tables = set()

            for doc, score in docs:
                metadata = doc.metadata
                item = {
                    "name": metadata["full_name"],
                    "description": metadata["description"],
                    "type": metadata["type"],
                    "score": float(score)
                }

                if metadata["type"] == "table" and metadata["full_name"] not in seen_tables:
                    relevant_tables.append(item)
                    seen_tables.add(metadata["full_name"])
                elif metadata["type"] == "column":
                    relevant_columns.append(item)
                    # Ensure associated table is also included
                    table_name = f"{metadata['schema']}.{metadata['table']}"
                    if table_name not in seen_tables:
                        relevant_tables.append({
                            "name": table_name,
                            "description": f"Table containing column {metadata['name']}",
                            "type": "table",
                            "score": float(score) * 0.9  # Slightly lower table score
                        })
                        seen_tables.add(table_name)

            return relevant_tables, relevant_columns

        except Exception as e:
            logging.error(f"Error during schema search: {str(e)}")
            raise

    def generate_schema_context(self, relevant_tables: List[Dict], relevant_columns: List[Dict]) -> str:
        """
        Generate concise schema context for LLM usage.

        Args:
            relevant_tables: List of relevant tables
            relevant_columns: List of relevant columns

        Returns:
            Formatted schema context string (markdown sections, sorted by
            descending relevance score).
        """
        context = "## Relevant Database Schema Information:\n\n"

        if relevant_tables:
            context += "### Tables:\n"
            for table in sorted(relevant_tables, key=lambda x: -x["score"]):
                context += f"- {table['name']}: {table['description']} (relevance score: {table['score']:.2f})\n"

        if relevant_columns:
            context += "\n### Columns:\n"
            for column in sorted(relevant_columns, key=lambda x: -x["score"]):
                context += f"- {column['name']}: {column['description']} (relevance score: {column['score']:.2f})\n"

        return context