Skip to content

Commit 719c8c9

Browse files
Streamlit 코드 수정
1 parent 07bf11f commit 719c8c9

File tree

1 file changed

+46
-6
lines changed

1 file changed

+46
-6
lines changed

interface/streamlit_app.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@
22
from langchain_core.messages import HumanMessage
33
from llm_utils.graph import builder
44
from langchain.chains.sql_database.prompt import SQL_PROMPTS
5+
import os
6+
from typing import Union
7+
import pandas as pd
8+
from clickhouse_driver import Client
9+
from connect_db import ConnectDB
10+
from dotenv import load_dotenv
11+
12+
13+
# Clickhouse 연결
14+
db = ConnectDB()
15+
db.connect_to_clickhouse()
516

617
# Streamlit 앱 제목
718
st.title("Lang2SQL")
@@ -17,6 +28,22 @@
1728
options=SQL_PROMPTS.keys(),
1829
index=0,
1930
)
31+
st.sidebar.title("Output Settings")
32+
st.sidebar.checkbox("Show Total Token Usage", value=True, key="show_total_token_usage")
33+
st.sidebar.checkbox(
34+
"Show Result Description", value=True, key="show_result_description"
35+
)
36+
st.sidebar.checkbox("Show SQL", value=True, key="show_sql")
37+
st.sidebar.checkbox(
38+
"Show User Question Reinterpreted by AI",
39+
value=True,
40+
key="show_question_reinterpreted_by_ai",
41+
)
42+
st.sidebar.checkbox(
43+
"Show List of Referenced Tables", value=True, key="show_referenced_tables"
44+
)
45+
st.sidebar.checkbox("Show Table", value=True, key="show_table")
46+
st.sidebar.checkbox("Show Chart", value=True, key="show_chart")
2047

2148

2249
# Token usage 집계 함수 정의
@@ -43,9 +70,22 @@ def summarize_total_tokens(data):
4370
total_tokens = summarize_total_tokens(res["messages"])
4471

4572
# 결과 출력
46-
st.write("총 토큰 사용량:", total_tokens)
47-
# st.write("결과:", res["generated_query"].content)
48-
st.write("결과:", "\n\n```sql\n" + res["generated_query"] + "\n```")
49-
st.write("결과 설명:\n\n", res["messages"][-1].content)
50-
st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
51-
st.write("참고한 테이블 목록:", res["searched_tables"])
73+
if st.session_state.get("show_total_token_usage", True):
74+
st.write("총 토큰 사용량:", total_tokens)
75+
if st.session_state.get("show_sql", True):
76+
st.write("결과:", res["generated_query"].content)
77+
if st.session_state.get("show_result_description", True):
78+
st.write("결과 설명:\n\n", res["messages"][-1].content)
79+
if st.session_state.get("show_question_reinterpreted_by_ai", True):
80+
st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content)
81+
if st.session_state.get("show_referenced_tables", True):
82+
st.write("참고한 테이블 목록:", res["searched_tables"])
83+
if st.session_state.get("show_table", True):
84+
sql = res["generated_query"].content.split("```")[1][
85+
3:
86+
] # 쿼리 앞쪽의 "sql " 제거
87+
df = db.run_sql(sql)
88+
if len(df) > 10:
89+
st.dataframe(df.head(10))
90+
else:
91+
st.dataframe(df)

0 commit comments

Comments
 (0)