1+ """
2+ Lang2SQL Streamlit 애플리케이션.
3+
4+ 자연어로 입력된 질문을 SQL 쿼리로 변환하고,
5+ ClickHouse 데이터베이스에 실행한 결과를 출력합니다.
6+ """
7+
18import streamlit as st
2- from langchain_core .messages import HumanMessage
3- from llm_utils .graph import builder
49from langchain .chains .sql_database .prompt import SQL_PROMPTS
5- import os
6- from typing import Union
7- import pandas as pd
10+ from langchain_core .messages import HumanMessage
811
9- from clickhouse_driver import Client
1012from llm_utils .connect_db import ConnectDB
11- from dotenv import load_dotenv
13+ from llm_utils . graph import builder
1214
15+ DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
16+ SIDEBAR_OPTIONS = {
17+ "show_total_token_usage" : "Show Total Token Usage" ,
18+ "show_result_description" : "Show Result Description" ,
19+ "show_sql" : "Show SQL" ,
20+ "show_question_reinterpreted_by_ai" : "Show User Question Reinterpreted by AI" ,
21+ "show_referenced_tables" : "Show List of Referenced Tables" ,
22+ "show_table" : "Show Table" ,
23+ "show_chart" : "Show Chart" ,
24+ }
1325
14- # Clickhouse 연결
15- db = ConnectDB ()
16- db .connect_to_clickhouse ()
1726
18- # Streamlit 앱 제목
19- st .title ("Lang2SQL" )
27+ def summarize_total_tokens (data : list ) -> int :
28+ """
29+ 메시지 데이터에서 총 토큰 사용량을 집계합니다.
2030
21- # 사용자 입력 받기
22- user_query = st .text_area (
23- "쿼리를 입력하세요:" ,
24- value = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" ,
25- )
26-
27- user_database_env = st .selectbox (
28- "db 환경정보를 입력하세요:" ,
29- options = SQL_PROMPTS .keys (),
30- index = 0 ,
31- )
32- st .sidebar .title ("Output Settings" )
33- st .sidebar .checkbox ("Show Total Token Usage" , value = True , key = "show_total_token_usage" )
34- st .sidebar .checkbox (
35- "Show Result Description" , value = True , key = "show_result_description"
36- )
37- st .sidebar .checkbox ("Show SQL" , value = True , key = "show_sql" )
38- st .sidebar .checkbox (
39- "Show User Question Reinterpreted by AI" ,
40- value = True ,
41- key = "show_question_reinterpreted_by_ai" ,
42- )
43- st .sidebar .checkbox (
44- "Show List of Referenced Tables" , value = True , key = "show_referenced_tables"
45- )
46- st .sidebar .checkbox ("Show Table" , value = True , key = "show_table" )
47- st .sidebar .checkbox ("Show Chart" , value = True , key = "show_chart" )
31+ Args:
32+ data (list): usage_metadata를 포함하는 객체들의 리스트.
4833
34+ Returns:
35+ int: 총 토큰 사용량 합계.
36+ """
4937
50- # Token usage 집계 함수 정의
51- def summarize_total_tokens (data ):
5238 total_tokens = 0
5339 for item in data :
5440 token_usage = getattr (item , "usage_metadata" , {})
5541 total_tokens += token_usage .get ("total_tokens" , 0 )
5642 return total_tokens
5743
5844
59- # 버튼 클릭 시 실행
60- if st .button ("쿼리 실행" ):
61- # 그래프 컴파일 및 쿼리 실행
62- graph = builder .compile ()
45+ def execute_query (
46+ * ,
47+ query : str ,
48+ database_env : str ,
49+ ) -> dict :
50+ """
51+ Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다.
52+
53+ Args:
54+ query (str): 자연어로 작성된 사용자 쿼리.
55+ database_env (str): 사용할 데이터베이스 환경 설정 이름.
56+
57+ Returns:
58+ dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리.
59+ """
6360
61+ graph = builder .compile ()
6462 res = graph .invoke (
6563 input = {
66- "messages" : [HumanMessage (content = user_query )],
67- "user_database_env" : user_database_env ,
64+ "messages" : [HumanMessage (content = query )],
65+ "user_database_env" : database_env ,
6866 "best_practice_query" : "" ,
6967 }
7068 )
69+
70+ return res
71+
72+
73+ def display_result (
74+ * ,
75+ res : dict ,
76+ database : ConnectDB ,
77+ ) -> None :
78+ """
79+ Lang2SQL 실행 결과를 Streamlit 화면에 출력합니다.
80+
81+ Args:
82+ res (dict): Lang2SQL 실행 결과 딕셔너리.
83+ database (ConnectDB): SQL 쿼리 실행을 위한 데이터베이스 연결 객체.
84+
85+ 출력 항목:
86+ - 총 토큰 사용량
87+ - 생성된 SQL 쿼리
88+ - 결과 설명
89+ - AI가 재해석한 사용자 질문
90+ - 참조된 테이블 목록
91+ - 쿼리 실행 결과 테이블
92+ """
7193 total_tokens = summarize_total_tokens (res ["messages" ])
7294
73- # 결과 출력
7495 if st .session_state .get ("show_total_token_usage" , True ):
7596 st .write ("총 토큰 사용량:" , total_tokens )
7697 if st .session_state .get ("show_sql" , True ):
77- st .write ("결과:" , "\n \n ```sql\n " + res ["generated_query" ] + "\n ```" )
98+ st .write ("결과:" , "\n \n ```sql\n " + res ["generated_query" ]. content + "\n ```" )
7899 if st .session_state .get ("show_result_description" , True ):
79100 st .write ("결과 설명:\n \n " , res ["messages" ][- 1 ].content )
80101 if st .session_state .get ("show_question_reinterpreted_by_ai" , True ):
@@ -83,8 +104,32 @@ def summarize_total_tokens(data):
83104 st .write ("참고한 테이블 목록:" , res ["searched_tables" ])
84105 if st .session_state .get ("show_table" , True ):
85106 sql = res ["generated_query" ]
86- df = db .run_sql (sql )
87- if len (df ) > 10 :
88- st .dataframe (df .head (10 ))
89- else :
90- st .dataframe (df )
107+ df = database .run_sql (sql )
108+ st .dataframe (df .head (10 ) if len (df ) > 10 else df )
109+
110+
111+ db = ConnectDB ()
112+ db .connect_to_clickhouse ()
113+
114+ st .title ("Lang2SQL" )
115+
116+ user_query = st .text_area (
117+ "쿼리를 입력하세요:" ,
118+ value = DEFAULT_QUERY ,
119+ )
120+ user_database_env = st .selectbox (
121+ "DB 환경정보를 입력하세요:" ,
122+ options = SQL_PROMPTS .keys (),
123+ index = 0 ,
124+ )
125+
126+ st .sidebar .title ("Output Settings" )
127+ for key , label in SIDEBAR_OPTIONS .items ():
128+ st .sidebar .checkbox (label , value = True , key = key )
129+
130+ if st .button ("쿼리 실행" ):
131+ result = execute_query (
132+ query = user_query ,
133+ database_env = user_database_env ,
134+ )
135+ display_result (res = result , database = db )
0 commit comments