Commit fdc6874

Merge remote-tracking branch 'origin/main'

2 parents: 85c3564 + a881d47
File tree

17 files changed: +880 −27 lines


.gitignore

Lines changed: 2 additions & 1 deletion

@@ -182,4 +182,5 @@ cython_debug/
 # PyPI configuration file
 .pypirc
 
-.DS_Store
+.DS_Store
+test.py
Lines changed: 49 additions & 0 deletions (new file)

"""005_table_and_field

Revision ID: 0a6f11be9be4
Revises: 8fe654655905
Create Date: 2025-05-15 10:20:25.686576

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = '0a6f11be9be4'
down_revision = '8fe654655905'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('core_field',
        sa.Column('id', sa.Integer(), sa.Identity(always=True), nullable=False),
        sa.Column('ds_id', sa.BigInteger(), nullable=True),
        sa.Column('table_id', sa.BigInteger(), nullable=True),
        sa.Column('checked', sa.Boolean(), nullable=False),
        sa.Column('field_name', sa.Text(), nullable=True),
        sa.Column('field_type', sqlmodel.sql.sqltypes.AutoString(length=128), nullable=True),
        sa.Column('field_comment', sa.Text(), nullable=True),
        sa.Column('custom_comment', sa.Text(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_table('core_table',
        sa.Column('id', sa.Integer(), sa.Identity(always=True), nullable=False),
        sa.Column('ds_id', sa.BigInteger(), nullable=True),
        sa.Column('checked', sa.Boolean(), nullable=False),
        sa.Column('table_name', sa.Text(), nullable=True),
        sa.Column('table_comment', sa.Text(), nullable=True),
        sa.Column('custom_comment', sa.Text(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('core_table')
    op.drop_table('core_field')
    # ### end Alembic commands ###
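
To exercise this revision locally, Alembic's command API (equivalent to the `alembic upgrade` / `alembic downgrade` CLI) can be driven from Python. A minimal sketch; the "alembic.ini" path is an assumption about the project layout, while the revision IDs come from the migration above:

    # Sketch: apply and roll back this revision via Alembic's command API.
    from alembic import command
    from alembic.config import Config

    cfg = Config("alembic.ini")  # assumed location of the Alembic config
    command.upgrade(cfg, "0a6f11be9be4")    # create core_table and core_field
    command.downgrade(cfg, "8fe654655905")  # drop them again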

backend/apps/chat/api/chat.py

Lines changed: 1 addition & 2 deletions

@@ -3,11 +3,10 @@
 from sqlmodel import select
 from apps.chat.schemas.chat_base_schema import LLMConfig
 from apps.chat.schemas.chat_schema import ChatQuestion
-from apps.chat.schemas.llm import AgentService, LLMService
+from apps.chat.schemas.llm import AgentService
 from apps.datasource.models.datasource import CoreDatasource
 from apps.system.models.system_model import AiModelDetail
 from common.core.deps import SessionDep
-from sse_starlette.sse import EventSourceResponse
 import json
 import asyncio
 

Lines changed: 276 additions & 0 deletions (new file)

import os
from typing import List, Dict, Tuple, Optional
from sqlalchemy import inspect, text
from sqlmodel import SQLModel, create_engine, Session
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores import PGVector
from langchain.schema import Document
import logging
from apps.chat.schemas.chat_base_schema import LLMConfig
from common.core.config import settings
from common.core.db import engine


class SchemaEmbeddingManager:
    def __init__(self, config: LLMConfig):
        """Initialize SchemaEmbeddingManager

        Args:
            config: LLM configuration providing the embedding model name,
                API base URL and API key
        """
        self.embedding_model = OpenAIEmbeddings(
            model=config.model_name,
            openai_api_base=config.api_base_url,
            api_key=config.api_key
        )
        # Reuse the shared application engine rather than creating a new one
        self.engine = engine
        # PGVector needs a plain connection string; derive it from the shared engine
        self.db_uri = engine.url.render_as_string(hide_password=False)
        self._setup_vector_extension()

    def _setup_vector_extension(self):
        """Set up the pgvector extension using SQLModel"""
        try:
            with Session(self.engine) as session:
                # Enable the vector extension
                session.exec(text("CREATE EXTENSION IF NOT EXISTS vector;"))
                session.commit()

                # Fallback: only define a vector type if the extension did not provide one
                session.exec(text("""
                    DO $$
                    BEGIN
                        IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'vector') THEN
                            CREATE TYPE vector AS (x float8[]);
                        END IF;
                    END
                    $$;
                """))
                session.commit()
            logging.info("Vector extension setup completed successfully")
        except Exception as e:
            logging.error(f"Failed to setup vector extension: {str(e)}")
            raise

    def get_session(self) -> Session:
        """Get a new database session"""
        return Session(self.engine)

    def extract_schema_info(self, schema_name: str = "public") -> List[Dict]:
        """
        Extract database schema information and structure it

        Args:
            schema_name: Schema name to extract, defaults to public

        Returns:
            List of dictionaries containing schema information
        """
        metadata = SQLModel.metadata
        metadata.reflect(bind=self.engine, schema=schema_name)

        schema_info = []
        inspector = inspect(self.engine)

        for table_name in inspector.get_table_names(schema=schema_name):
            table_comment = inspector.get_table_comment(table_name, schema=schema_name)

            # Table-level information; get_table_comment returns {'text': None}
            # when no comment is set, hence the `or` fallback
            table_info = {
                "type": "table",
                "schema": schema_name,
                "name": table_name,
                "description": table_comment.get('text') or f"Table {table_name} in schema {schema_name}",
                "columns": []
            }

            # get_pk_constraint returns plain column-name strings under 'constrained_columns'
            pk_columns = inspector.get_pk_constraint(table_name, schema=schema_name).get('constrained_columns', [])

            # Column-level information
            for column in inspector.get_columns(table_name, schema=schema_name):
                col_name = column['name']
                # The 'comment' key may be present but None, hence the `or` fallback
                col_comment = column.get('comment') or f"Column {col_name} in table {table_name}"

                column_info = {
                    "type": "column",
                    "schema": schema_name,
                    "table": table_name,
                    "name": col_name,
                    "data_type": str(column['type']),
                    "description": col_comment,
                    "is_primary_key": col_name in pk_columns,
                    "foreign_key": None  # Will be updated below
                }

                # Get foreign key information
                for fk in inspector.get_foreign_keys(table_name, schema=schema_name):
                    if col_name in fk['constrained_columns']:
                        column_info["foreign_key"] = f"{fk['referred_table']}.{fk['referred_columns'][0]}"
                        break

                table_info["columns"].append(column_info)

            schema_info.append(table_info)

        return schema_info

    def generate_schema_documents(self, schema_info: List[Dict]) -> List[Document]:
        """
        Convert schema information into LangChain Document objects

        Args:
            schema_info: Schema information from extract_schema_info

        Returns:
            List of LangChain Document objects
        """
        documents = []

        for table in schema_info:
            # Table-level document
            table_content = f"Table {table['schema']}.{table['name']}: {table['description']}. "
            table_content += f"Contains columns: {', '.join(col['name'] for col in table['columns'])}"

            table_doc = Document(
                page_content=table_content,
                metadata={
                    "type": "table",
                    "schema": table["schema"],
                    "name": table["name"],
                    "description": table["description"],
                    "full_name": f"{table['schema']}.{table['name']}"
                }
            )
            documents.append(table_doc)

            # Column-level documents
            for column in table["columns"]:
                column_content = f"Column {table['schema']}.{table['name']}.{column['name']}: {column['description']}. "
                column_content += f"Data type: {column['data_type']}. "
                if column["is_primary_key"]:
                    column_content += "This is a primary key. "
                if column["foreign_key"]:
                    column_content += f"Foreign key to {column['foreign_key']}."

                column_doc = Document(
                    page_content=column_content,
                    metadata={
                        "type": "column",
                        "schema": table["schema"],
                        "table": table["name"],
                        "name": column["name"],
                        "description": column["description"],
                        "data_type": column["data_type"],
                        "full_name": f"{table['schema']}.{table['name']}.{column['name']}"
                    }
                )
                documents.append(column_doc)

        return documents

    def store_embeddings(self, documents: List[Document], collection_name: str = "schema_embeddings"):
        """Store schema document embeddings in PGVector

        Args:
            documents: Documents from generate_schema_documents
            collection_name: Vector collection name
        """
        try:
            PGVector.from_documents(
                embedding=self.embedding_model,
                documents=documents,
                collection_name=collection_name,
                connection_string=self.db_uri,
                pre_delete_collection=True  # Optionally recreate collection
            )
            logging.info(f"Successfully stored {len(documents)} embeddings in collection {collection_name}")
        except Exception as e:
            logging.error(f"Failed to store embeddings: {str(e)}")
            raise

    def get_relevant_schema(
        self,
        query: str,
        collection_name: str = "schema_embeddings",
        top_k: int = 5
    ) -> Tuple[List[Dict], List[Dict]]:
        """Get relevant schema information based on user query

        Args:
            query: User query text
            collection_name: Vector collection name
            top_k: Number of most relevant items to return

        Returns:
            Tuple containing (relevant_tables_list, relevant_columns_list)
        """
        try:
            store = PGVector(
                collection_name=collection_name,
                connection_string=self.db_uri,
                embedding_function=self.embedding_model,
            )

            # Execute similarity search
            docs = store.similarity_search_with_score(query, k=top_k)

            relevant_tables = []
            relevant_columns = []
            seen_tables = set()

            for doc, score in docs:
                metadata = doc.metadata
                item = {
                    "name": metadata["full_name"],
                    "description": metadata["description"],
                    "type": metadata["type"],
                    "score": float(score)
                }

                if metadata["type"] == "table" and metadata["full_name"] not in seen_tables:
                    relevant_tables.append(item)
                    seen_tables.add(metadata["full_name"])
                elif metadata["type"] == "column":
                    relevant_columns.append(item)
                    # Ensure the associated table is also included
                    table_name = f"{metadata['schema']}.{metadata['table']}"
                    if table_name not in seen_tables:
                        relevant_tables.append({
                            "name": table_name,
                            "description": f"Table containing column {metadata['name']}",
                            "type": "table",
                            "score": float(score) * 0.9  # Slightly lower table score
                        })
                        seen_tables.add(table_name)

            return relevant_tables, relevant_columns

        except Exception as e:
            logging.error(f"Error during schema search: {str(e)}")
            raise

    def generate_schema_context(self, relevant_tables: List[Dict], relevant_columns: List[Dict]) -> str:
        """
        Generate concise schema context for LLM usage

        Args:
            relevant_tables: List of relevant tables
            relevant_columns: List of relevant columns

        Returns:
            Formatted schema context string
        """
        context = "## Relevant Database Schema Information:\n\n"

        if relevant_tables:
            context += "### Tables:\n"
            for table in sorted(relevant_tables, key=lambda x: -x["score"]):
                context += f"- {table['name']}: {table['description']} (relevance score: {table['score']:.2f})\n"

        if relevant_columns:
            context += "\n### Columns:\n"
            for column in sorted(relevant_columns, key=lambda x: -x["score"]):
                context += f"- {column['name']}: {column['description']} (relevance score: {column['score']:.2f})\n"

        return context
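
For reference, a minimal end-to-end sketch of how this manager might be wired up. The import path for SchemaEmbeddingManager is assumed (the new file's path is not shown in this view), and the LLMConfig field values are placeholders, not project defaults:

    # Hypothetical usage sketch: module path and config values are assumptions.
    from apps.chat.schemas.schema_embedding import SchemaEmbeddingManager  # assumed path
    from apps.chat.schemas.chat_base_schema import LLMConfig

    config = LLMConfig(
        model_name="text-embedding-3-small",       # placeholder embedding model
        api_base_url="https://api.openai.com/v1",  # placeholder endpoint
        api_key="sk-...",
    )
    manager = SchemaEmbeddingManager(config)

    # 1. Introspect the live database schema.
    schema_info = manager.extract_schema_info(schema_name="public")
    # 2. Turn tables and columns into LangChain documents.
    documents = manager.generate_schema_documents(schema_info)
    # 3. Embed and persist them in the pgvector-backed collection.
    manager.store_embeddings(documents, collection_name="schema_embeddings")

    # 4. At question time, retrieve only the relevant slice of the schema.
    tables, columns = manager.get_relevant_schema("monthly sales by region", top_k=5)
    prompt_context = manager.generate_schema_context(tables, columns)
    print(prompt_context)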

backend/apps/chat/schemas/llm.py

Lines changed: 7 additions & 2 deletions

@@ -2,6 +2,7 @@
 from langgraph.prebuilt import create_react_agent
 from langchain_core.prompts import ChatPromptTemplate
 from apps.chat.schemas.chat_base_schema import LLMConfig, LLMFactory
+from apps.chat.schemas.schema_engine import SchemaEngine
 from apps.datasource.models.datasource import CoreDatasource
 from apps.db.db import exec_sql, get_uri
 from common.core.config import settings
@@ -103,9 +104,13 @@ def generate_sql(self, question: str) -> str:
     async def async_generate(self, question: str) -> AsyncGenerator[str, None]:
 
         chain = self.prompt | self.agent_executor
-        schema = self.db.get_table_info()
+        # schema = self.db.get_table_info()
+
+        schema_engine = SchemaEngine(engine=self.db._engine)
+        mschema = schema_engine.mschema
+        mschema_str = mschema.to_mschema()
 
-        async for chunk in chain.astream({"schema": schema, "question": question}):
+        async for chunk in chain.astream({"schema": mschema_str, "question": question}):
             if not isinstance(chunk, dict):
                 continue
 