1- import os
2- from typing import Union
1+ """
2+ Lang2SQL Streamlit 애플리케이션.
3+
4+ 자연어로 입력된 질문을 SQL 쿼리로 변환하고,
5+ ClickHouse 데이터베이스에 실행한 결과를 출력합니다.
6+ """
37
4- import pandas as pd
58import streamlit as st
6- from clickhouse_driver import Client
7- from dotenv import load_dotenv
89from langchain .chains .sql_database .prompt import SQL_PROMPTS
910from langchain_core .messages import HumanMessage
1011
1112from llm_utils .connect_db import ConnectDB
1213from llm_utils .graph import builder
1314
14- # Clickhouse 연결
15- db = ConnectDB ()
16- db .connect_to_clickhouse ()
15+ DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
16+ SIDEBAR_OPTIONS = {
17+ "show_total_token_usage" : "Show Total Token Usage" ,
18+ "show_result_description" : "Show Result Description" ,
19+ "show_sql" : "Show SQL" ,
20+ "show_question_reinterpreted_by_ai" : "Show User Question Reinterpreted by AI" ,
21+ "show_referenced_tables" : "Show List of Referenced Tables" ,
22+ "show_table" : "Show Table" ,
23+ "show_chart" : "Show Chart" ,
24+ }
25+ def summarize_total_tokens (data : list ) -> int :
26+ """
27+ 메시지 데이터에서 총 토큰 사용량을 집계합니다.
1728
18- # Streamlit 앱 제목
19- st . title ( "Lang2SQL" )
29+ Args:
30+ data (list): usage_metadata를 포함하는 객체들의 리스트.
2031
21- # 사용자 입력 받기
22- user_query = st .text_area (
23- "쿼리를 입력하세요:" ,
24- value = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" ,
25- )
32+ Returns:
33+ int: 총 토큰 사용량 합계.
34+ """
2635
27- user_database_env = st .selectbox (
28- "db 환경정보를 입력하세요:" ,
29- options = SQL_PROMPTS .keys (),
30- index = 0 ,
31- )
32- st .sidebar .title ("Output Settings" )
33- st .sidebar .checkbox ("Show Total Token Usage" , value = True , key = "show_total_token_usage" )
34- st .sidebar .checkbox (
35- "Show Result Description" , value = True , key = "show_result_description"
36- )
37- st .sidebar .checkbox ("Show SQL" , value = True , key = "show_sql" )
38- st .sidebar .checkbox (
39- "Show User Question Reinterpreted by AI" ,
40- value = True ,
41- key = "show_question_reinterpreted_by_ai" ,
42- )
43- st .sidebar .checkbox (
44- "Show List of Referenced Tables" , value = True , key = "show_referenced_tables"
45- )
46- st .sidebar .checkbox ("Show Table" , value = True , key = "show_table" )
47- st .sidebar .checkbox ("Show Chart" , value = True , key = "show_chart" )
48-
49-
50- # Token usage 집계 함수 정의
51- def summarize_total_tokens (data ):
5236 total_tokens = 0
5337 for item in data :
5438 token_usage = getattr (item , "usage_metadata" , {})
5539 total_tokens += token_usage .get ("total_tokens" , 0 )
5640 return total_tokens
5741
5842
59- # 버튼 클릭 시 실행
60- if st .button ("쿼리 실행" ):
61- # 그래프 컴파일 및 쿼리 실행
62- graph = builder .compile ()
43+ def execute_query (
44+ * ,
45+ query : str ,
46+ database_env : str ,
47+ ) -> dict :
48+ """
49+ Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다.
50+
51+ Args:
52+ query (str): 자연어로 작성된 사용자 쿼리.
53+ database_env (str): 사용할 데이터베이스 환경 설정 이름.
54+
55+ Returns:
56+ dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리.
57+ """
58+ # 세션 상태에서 그래프 가져오기
59+ graph = st .session_state .get ("graph" )
60+ if graph is None :
61+ graph = builder .compile ()
62+ st .session_state ["graph" ] = graph
6363
6464 res = graph .invoke (
6565 input = {
66- "messages" : [HumanMessage (content = user_query )],
67- "user_database_env" : user_database_env ,
66+ "messages" : [HumanMessage (content = query )],
67+ "user_database_env" : database_env ,
6868 "best_practice_query" : "" ,
6969 }
7070 )
71+
72+ return res
73+
74+
75+ def display_result (
76+ * ,
77+ res : dict ,
78+ database : ConnectDB ,
79+ ) -> None :
80+ """
81+ Lang2SQL 실행 결과를 Streamlit 화면에 출력합니다.
82+
83+ Args:
84+ res (dict): Lang2SQL 실행 결과 딕셔너리.
85+ database (ConnectDB): SQL 쿼리 실행을 위한 데이터베이스 연결 객체.
86+
87+ 출력 항목:
88+ - 총 토큰 사용량
89+ - 생성된 SQL 쿼리
90+ - 결과 설명
91+ - AI가 재해석한 사용자 질문
92+ - 참조된 테이블 목록
93+ - 쿼리 실행 결과 테이블
94+ """
7195 total_tokens = summarize_total_tokens (res ["messages" ])
7296
73- # 결과 출력
7497 if st .session_state .get ("show_total_token_usage" , True ):
7598 st .write ("총 토큰 사용량:" , total_tokens )
7699 if st .session_state .get ("show_sql" , True ):
77- st .write ("결과:" , "\n \n ```sql\n " + res ["generated_query" ] + "\n ```" )
100+ st .write ("결과:" , "\n \n ```sql\n " + res ["generated_query" ]. content + "\n ```" )
78101 if st .session_state .get ("show_result_description" , True ):
79102 st .write ("결과 설명:\n \n " , res ["messages" ][- 1 ].content )
80103 if st .session_state .get ("show_question_reinterpreted_by_ai" , True ):
@@ -83,8 +106,42 @@ def summarize_total_tokens(data):
83106 st .write ("참고한 테이블 목록:" , res ["searched_tables" ])
84107 if st .session_state .get ("show_table" , True ):
85108 sql = res ["generated_query" ]
86- df = db .run_sql (sql )
87- if len (df ) > 10 :
88- st .dataframe (df .head (10 ))
89- else :
90- st .dataframe (df )
109+ df = database .run_sql (sql )
110+ st .dataframe (df .head (10 ) if len (df ) > 10 else df )
111+
112+
113+ db = ConnectDB ()
114+ db .connect_to_clickhouse ()
115+
116+ st .title ("Lang2SQL" )
117+
118+ # 세션 상태 초기화
119+ if "graph" not in st .session_state :
120+ st .session_state ["graph" ] = builder .compile ()
121+ st .info ("Lang2SQL이 성공적으로 시작되었습니다." )
122+
123+ # 새로고침 버튼 추가
124+ if st .sidebar .button ("Lang2SQL 새로고침" ):
125+ st .session_state ["graph" ] = builder .compile ()
126+ st .sidebar .success ("Lang2SQL이 성공적으로 새로고침되었습니다." )
127+
128+ user_query = st .text_area (
129+ "쿼리를 입력하세요:" ,
130+ value = DEFAULT_QUERY ,
131+ )
132+ user_database_env = st .selectbox (
133+ "DB 환경정보를 입력하세요:" ,
134+ options = SQL_PROMPTS .keys (),
135+ index = 0 ,
136+ )
137+
138+ st .sidebar .title ("Output Settings" )
139+ for key , label in SIDEBAR_OPTIONS .items ():
140+ st .sidebar .checkbox (label , value = True , key = key )
141+
142+ if st .button ("쿼리 실행" ):
143+ result = execute_query (
144+ query = user_query ,
145+ database_env = user_database_env ,
146+ )
147+ display_result (res = result , database = db )
0 commit comments