51 changes: 45 additions & 6 deletions interface/lang2sql.py
@@ -2,6 +2,18 @@
from langchain_core.messages import HumanMessage
from llm_utils.graph import builder
from langchain.chains.sql_database.prompt import SQL_PROMPTS
import os
from typing import Union
import pandas as pd

from clickhouse_driver import Client
from llm_utils.connect_db import ConnectDB
from dotenv import load_dotenv


# ClickHouse connection
db = ConnectDB()
db.connect_to_clickhouse()

# Streamlit app title
st.title("Lang2SQL")
@@ -17,6 +29,22 @@
    options=SQL_PROMPTS.keys(),
    index=0,
)
st.sidebar.title("Output Settings")
st.sidebar.checkbox("Show Total Token Usage", value=True, key="show_total_token_usage")
st.sidebar.checkbox(
    "Show Result Description", value=True, key="show_result_description"
)
st.sidebar.checkbox("Show SQL", value=True, key="show_sql")
st.sidebar.checkbox(
    "Show User Question Reinterpreted by AI",
    value=True,
    key="show_question_reinterpreted_by_ai",
)
st.sidebar.checkbox(
    "Show List of Referenced Tables", value=True, key="show_referenced_tables"
)
st.sidebar.checkbox("Show Table", value=True, key="show_table")
st.sidebar.checkbox("Show Chart", value=True, key="show_chart")


# Define the token-usage aggregation function
Expand All @@ -43,9 +71,20 @@ def summarize_total_tokens(data):
total_tokens = summarize_total_tokens(res["messages"])

# Display results
st.write("총 토큰 사용량:", total_tokens)
# st.write("결과:", res["generated_query"].content)
st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
st.write("결과 설명:\n\n", res["messages"][-1].content)
st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
st.write("참고한 테이블 목록:", res["searched_tables"])
if st.session_state.get("show_total_token_usage", True):
    st.write("총 토큰 사용량:", total_tokens)
if st.session_state.get("show_sql", True):
    st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
if st.session_state.get("show_result_description", True):
    st.write("결과 설명:\n\n", res["messages"][-1].content)
if st.session_state.get("show_question_reinterpreted_by_ai", True):
    st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
if st.session_state.get("show_referenced_tables", True):
    st.write("참고한 테이블 목록:", res["searched_tables"])
if st.session_state.get("show_table", True):
    sql = res["generated_query"]
    df = db.run_sql(sql)
    if len(df) > 10:
        st.dataframe(df.head(10))
    else:
        st.dataframe(df)
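Note: the checkbox/session-state pairing above is the core of this change. Each st.sidebar.checkbox registers its value in st.session_state under its key, and each output section reads the flag back with a matching default. A minimal standalone sketch of the same pattern (the label and key here are illustrative, not the PR's):

import streamlit as st

# The checkbox writes its state into st.session_state under `key`.
st.sidebar.checkbox("Show greeting", value=True, key="show_greeting")

# Downstream code reads the flag back; the default mirrors the checkbox
# default so the first render behaves consistently.
if st.session_state.get("show_greeting", True):
    st.write("Hello from a toggled section")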
1 change: 0 additions & 1 deletion interface/streamlit_app.py
@@ -1,6 +1,5 @@
import streamlit as st


pg = st.navigation(
    [
        st.Page("lang2sql.py", title="Lang2SQL"),
6 changes: 1 addition & 5 deletions llm_utils/chains.py
@@ -12,11 +12,7 @@
else:
    print(f"⚠️ 환경변수 파일(.env)이 {os.getcwd()}에 없습니다!")

llm = get_llm(
    model_type="openai",
    model_name="gpt-4o-mini",
    openai_api_key=os.getenv("OPENAI_API_KEY"),
)
llm = get_llm()


def create_query_refiner_chain(llm):
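The hard-coded OpenAI arguments are dropped in favor of a bare get_llm() call, moving provider and model selection into one place. get_llm's implementation is not part of this diff; a hedged sketch of what an env-driven factory along these lines could look like (the LLM_PROVIDER and LLM_MODEL variable names are assumptions, not the project's actual configuration):

import os
from typing import Optional

from dotenv import load_dotenv

load_dotenv()


def get_llm(model_type: Optional[str] = None, **kwargs):
    """Resolve an LLM from explicit arguments or environment variables."""
    # Hypothetical env vars; the real get_llm may use different names.
    model_type = model_type or os.getenv("LLM_PROVIDER", "openai")
    if model_type == "openai":
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(
            model=kwargs.get("model_name", os.getenv("LLM_MODEL", "gpt-4o-mini")),
            api_key=kwargs.get("openai_api_key", os.getenv("OPENAI_API_KEY")),
        )
    if model_type == "ollama":
        from langchain_ollama import ChatOllama

        return ChatOllama(model=os.getenv("LLM_MODEL", "llama3"))
    raise ValueError(f"Unsupported model_type: {model_type}")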
44 changes: 44 additions & 0 deletions llm_utils/connect_db.py
@@ -0,0 +1,44 @@
import os
from typing import Union
import pandas as pd
from clickhouse_driver import Client
from dotenv import load_dotenv

# Load environment variables
load_dotenv()


class ConnectDB:
    def __init__(self):
        self.client = None
        self.host = os.getenv("CLICKHOUSE_HOST")  # e.g. '127.0.0.1'
        self.dbname = os.getenv("CLICKHOUSE_DATABASE")
        self.user = os.getenv("CLICKHOUSE_USER")
        self.password = os.getenv("CLICKHOUSE_PASSWORD")
        self.port = os.getenv("CLICKHOUSE_PORT")  # default TCP port is 9000

    def connect_to_clickhouse(self):
        # ClickHouse server connection
        self.client = Client(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            database=self.dbname,
        )

    def run_sql(self, sql: str) -> Union[pd.DataFrame, None]:
        if self.client:
            try:
                result = self.client.execute(sql, with_column_types=True)
                # Split the result into rows and column metadata
                rows, columns = result
                column_names = [col[0] for col in columns]

                # Create a pandas DataFrame from the results
                df = pd.DataFrame(rows, columns=column_names)
                return df

            except Exception as e:
                raise e
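For reference, a minimal usage sketch of the new class, assuming the CLICKHOUSE_* variables are defined in .env (the probe query is illustrative):

from llm_utils.connect_db import ConnectDB

# Connect once at startup; the Client is built from the CLICKHOUSE_* env vars.
db = ConnectDB()
db.connect_to_clickhouse()

# run_sql returns a pandas DataFrame, or None when no client exists.
df = db.run_sql("SELECT 1 AS probe")
print(df)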
7 changes: 1 addition & 6 deletions llm_utils/graph.py
@@ -62,7 +62,6 @@ def get_table_info_node(state: QueryMakerState):
    documents = get_info_from_db()
    db = FAISS.from_documents(documents, embeddings)
    db.save_local(os.getcwd() + "/table_info_db")
    print("table_info_db not found")
    doc_res = db.similarity_search(state["messages"][-1].content)
    documents_dict = {}

@@ -112,11 +111,7 @@ class SQLResult(BaseModel):

def query_maker_node_with_db_guide(state: QueryMakerState):
    sql_prompt = SQL_PROMPTS[state["user_database_env"]]
    llm = get_llm(
        model_type="openai",
        model_name="gpt-4o-mini",
        openai_api_key=os.getenv("OPENAI_API_KEY"),
    )
    llm = get_llm()
    chain = sql_prompt | llm.with_structured_output(SQLResult)
    res = chain.invoke(
        input={
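The with_structured_output(SQLResult) line does the heavy lifting here: the LLM's reply is parsed directly into the Pydantic model instead of arriving as a raw chat message. A self-contained sketch of that pattern, assuming OPENAI_API_KEY is set (the SQLResult field names below are illustrative, not necessarily the project's schema):

from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI


class SQLResult(BaseModel):
    sql: str  # generated SQL statement
    explanation: str  # short rationale for the query


prompt = ChatPromptTemplate.from_template("Write a SQL query: {question}")
llm = ChatOpenAI(model="gpt-4o-mini")  # requires OPENAI_API_KEY in the environment

# Piping the prompt into with_structured_output yields parsed SQLResult
# instances instead of raw messages.
chain = prompt | llm.with_structured_output(SQLResult)
result = chain.invoke({"question": "count users per day"})
print(result.sql)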
4 changes: 4 additions & 0 deletions setup.py
@@ -24,6 +24,10 @@
"streamlit==1.41.1",
"python-dotenv==1.0.1",
"faiss-cpu==1.10.0",
"langchain-aws>=0.2.21,<0.3.0",
"langchain-google-genai>=2.1.3,<3.0.0",
"langchain-ollama>=0.3.2,<0.4.0",
"langchain-huggingface>=0.1.2,<0.2.0",
],
entry_points={
"console_scripts": [