22from langchain_core .messages import HumanMessage
33from llm_utils .graph import builder
44from langchain .chains .sql_database .prompt import SQL_PROMPTS
5+ import os
6+ from typing import Union
7+ import pandas as pd
8+
9+ from clickhouse_driver import Client
10+ from llm_utils .connect_db import ConnectDB
11+ from dotenv import load_dotenv
12+
13+
14+ # Clickhouse 연결
15+ db = ConnectDB ()
16+ db .connect_to_clickhouse ()
517
618# Streamlit 앱 제목
719st .title ("Lang2SQL" )
1729 options = SQL_PROMPTS .keys (),
1830 index = 0 ,
1931)
32+ st .sidebar .title ("Output Settings" )
33+ st .sidebar .checkbox ("Show Total Token Usage" , value = True , key = "show_total_token_usage" )
34+ st .sidebar .checkbox (
35+ "Show Result Description" , value = True , key = "show_result_description"
36+ )
37+ st .sidebar .checkbox ("Show SQL" , value = True , key = "show_sql" )
38+ st .sidebar .checkbox (
39+ "Show User Question Reinterpreted by AI" ,
40+ value = True ,
41+ key = "show_question_reinterpreted_by_ai" ,
42+ )
43+ st .sidebar .checkbox (
44+ "Show List of Referenced Tables" , value = True , key = "show_referenced_tables"
45+ )
46+ st .sidebar .checkbox ("Show Table" , value = True , key = "show_table" )
47+ st .sidebar .checkbox ("Show Chart" , value = True , key = "show_chart" )
2048
2149
2250# Token usage 집계 함수 정의
@@ -43,9 +71,20 @@ def summarize_total_tokens(data):
4371 total_tokens = summarize_total_tokens (res ["messages" ])
4472
4573 # 결과 출력
46- st .write ("총 토큰 사용량:" , total_tokens )
47- # st.write("결과:", res["generated_query"].content)
48- st .write ("결과:" , "\n \n ```sql\n " + res ["generated_query" ] + "\n ```" )
49- st .write ("결과 설명:\n \n " , res ["messages" ][- 1 ].content )
50- st .write ("AI가 재해석한 사용자 질문:\n " , res ["refined_input" ].content )
51- st .write ("참고한 테이블 목록:" , res ["searched_tables" ])
74+ if st .session_state .get ("show_total_token_usage" , True ):
75+ st .write ("총 토큰 사용량:" , total_tokens )
76+ if st .session_state .get ("show_sql" , True ):
77+ st .write ("결과:" , "\n \n ```sql\n " + res ["generated_query" ] + "\n ```" )
78+ if st .session_state .get ("show_result_description" , True ):
79+ st .write ("결과 설명:\n \n " , res ["messages" ][- 1 ].content )
80+ if st .session_state .get ("show_question_reinterpreted_by_ai" , True ):
81+ st .write ("AI가 재해석한 사용자 질문:\n " , res ["refined_input" ].content )
82+ if st .session_state .get ("show_referenced_tables" , True ):
83+ st .write ("참고한 테이블 목록:" , res ["searched_tables" ])
84+ if st .session_state .get ("show_table" , True ):
85+ sql = res ["generated_query" ]
86+ df = db .run_sql (sql )
87+ if len (df ) > 10 :
88+ st .dataframe (df .head (10 ))
89+ else :
90+ st .dataframe (df )
0 commit comments