Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions interface/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@
from langchain_core.messages import HumanMessage
from llm_utils.graph import builder
from langchain.chains.sql_database.prompt import SQL_PROMPTS
import os
from typing import Union
import pandas as pd
from clickhouse_driver import Client
from connect_db import ConnectDB
from dotenv import load_dotenv


# Create the ClickHouse client wrapper used later to run the generated SQL.
db = ConnectDB()
db.connect_to_clickhouse()

# Streamlit app title
st.title("Lang2SQL")
Expand All @@ -17,6 +28,22 @@
options=SQL_PROMPTS.keys(),
index=0,
)
# Sidebar toggles that decide which parts of the result are rendered.
# Every toggle starts checked; its state is looked up later through
# st.session_state under the key given here.
st.sidebar.title("Output Settings")
for _toggle_label, _toggle_key in (
    ("Show Total Token Usage", "show_total_token_usage"),
    ("Show Result Description", "show_result_description"),
    ("Show SQL", "show_sql"),
    (
        "Show User Question Reinterpreted by AI",
        "show_question_reinterpreted_by_ai",
    ),
    ("Show List of Referenced Tables", "show_referenced_tables"),
    ("Show Table", "show_table"),
    ("Show Chart", "show_chart"),
):
    st.sidebar.checkbox(_toggle_label, value=True, key=_toggle_key)


# Token usage 집계 함수 정의
Expand All @@ -43,9 +70,22 @@ def summarize_total_tokens(data):
total_tokens = summarize_total_tokens(res["messages"])

# 결과 출력
st.write("총 토큰 사용량:", total_tokens)
# st.write("결과:", res["generated_query"].content)
st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
st.write("결과 설명:\n\n", res["messages"][-1].content)
st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
st.write("참고한 테이블 목록:", res["searched_tables"])
if st.session_state.get("show_total_token_usage", True):
st.write("총 토큰 사용량:", total_tokens)
if st.session_state.get("show_sql", True):
st.write("결과:", res["generated_query"].content)
if st.session_state.get("show_result_description", True):
st.write("결과 설명:\n\n", res["messages"][-1].content)
if st.session_state.get("show_question_reinterpreted_by_ai", True):
st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
if st.session_state.get("show_referenced_tables", True):
st.write("참고한 테이블 목록:", res["searched_tables"])
if st.session_state.get("show_table", True):
sql = res["generated_query"].content.split("```")[1][
3:
] # 쿼리 앞쪽의 "sql " 제거
df = db.run_sql(sql)
if len(df) > 10:
st.dataframe(df.head(10))
else:
st.dataframe(df)
44 changes: 44 additions & 0 deletions llm_utils/connect_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os
from typing import Union
import pandas as pd
from clickhouse_driver import Client
from dotenv import load_dotenv

# 환경변수
load_dotenv()


class ConnectDB:
    """Small wrapper around a ClickHouse client configured from the environment.

    The connection settings are read from the ``CLICKHOUSE_HOST``,
    ``CLICKHOUSE_DATABASE``, ``CLICKHOUSE_USER``, ``CLICKHOUSE_PASSWORD`` and
    ``CLICKHOUSE_PORT`` environment variables when the object is created.
    No client exists until :meth:`connect_to_clickhouse` is called.
    """

    def __init__(self):
        # Remains None until connect_to_clickhouse() has been called.
        self.client = None
        # os.getenv returns a str, or None when the variable is not set.
        self.host = os.getenv("CLICKHOUSE_HOST")
        self.dbname = os.getenv("CLICKHOUSE_DATABASE")
        self.user = os.getenv("CLICKHOUSE_USER")
        self.password = os.getenv("CLICKHOUSE_PASSWORD")
        self.port = os.getenv("CLICKHOUSE_PORT")

    def connect_to_clickhouse(self):
        """Instantiate the ClickHouse client from the stored settings."""
        self.client = Client(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            database=self.dbname,
        )

    def run_sql(self, sql: str) -> Union[pd.DataFrame, None]:
        """Run ``sql`` and return the result set as a DataFrame.

        Args:
            sql: The query text handed to the client's ``execute``.

        Returns:
            A DataFrame whose columns are named after the result columns, or
            ``None`` when no client has been created yet.

        Raises:
            Exception: Whatever the client raises while executing the query
                propagates unchanged to the caller.
        """
        if not self.client:
            # Not connected: keep the historical "no result" contract.
            return None

        # with_column_types=True makes execute return (rows, column metadata).
        rows, columns = self.client.execute(sql, with_column_types=True)
        # The first item of each column entry is used as the column name.
        column_names = [col[0] for col in columns]
        return pd.DataFrame(rows, columns=column_names)