77
88import streamlit as st
99from langchain .chains .sql_database .prompt import SQL_PROMPTS
10- from langchain_core .messages import AIMessage , HumanMessage
10+ from langchain_core .messages import AIMessage
1111
1212from llm_utils .connect_db import ConnectDB
1313from llm_utils .display_chart import DisplayChart
14- from llm_utils .graph_utils .enriched_graph import builder as enriched_builder
15- from llm_utils .graph_utils .basic_graph import builder
14+ from llm_utils .query_executor import execute_query as execute_query_common
1615from llm_utils .llm_response_parser import LLMResponseParser
1716from llm_utils .token_utils import TokenUtils
17+ from llm_utils .graph_utils .enriched_graph import builder as enriched_builder
18+ from llm_utils .graph_utils .basic_graph import builder
1819
1920TITLE = "Lang2SQL"
2021DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
@@ -40,9 +41,8 @@ def execute_query(
4041 """
4142 자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 Lang2SQL 그래프 인터페이스 함수입니다.
4243
43- 이 함수는 Lang2SQL 파이프라인(graph)을 세션 상태에서 가져오거나 새로 컴파일한 뒤,
44- 사용자의 자연어 질문을 SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다.
45- 내부적으로 LangChain의 `graph.invoke` 메서드를 호출합니다.
44+ 이 함수는 공용 execute_query 함수를 호출하여 Lang2SQL 파이프라인을 실행합니다.
45+ Streamlit 세션 상태를 활용하여 그래프를 재사용합니다.
4646
4747 Args:
4848 query (str): 사용자가 입력한 자연어 기반 질문.
@@ -59,27 +59,16 @@ def execute_query(
5959 - "searched_tables": 참조된 테이블 목록 등 추가 정보
6060 """
6161
62- graph = st .session_state .get ("graph" )
63- if graph is None :
64- graph_builder = (
65- enriched_builder if st .session_state .get ("use_enriched" ) else builder
66- )
67- graph = graph_builder .compile ()
68- st .session_state ["graph" ] = graph
69-
70- res = graph .invoke (
71- input = {
72- "messages" : [HumanMessage (content = query )],
73- "user_database_env" : database_env ,
74- "best_practice_query" : "" ,
75- "retriever_name" : retriever_name ,
76- "top_n" : top_n ,
77- "device" : device ,
78- }
62+ return execute_query_common (
63+ query = query ,
64+ database_env = database_env ,
65+ retriever_name = retriever_name ,
66+ top_n = top_n ,
67+ device = device ,
68+ use_enriched_graph = st .session_state .get ("use_enriched" , False ),
69+ session_state = st .session_state ,
7970 )
8071
81- return res
82-
8372
8473def display_result (
8574 * ,
@@ -120,40 +109,50 @@ def should_show(_key: str) -> bool:
120109 if should_show ("show_sql" ):
121110 st .markdown ("---" )
122111 generated_query = res .get ("generated_query" )
123- query_text = (
124- generated_query .content
125- if isinstance (generated_query , AIMessage )
126- else str (generated_query )
127- )
112+ if generated_query :
113+ query_text = (
114+ generated_query .content
115+ if isinstance (generated_query , AIMessage )
116+ else str (generated_query )
117+ )
128118
129- try :
130- sql = LLMResponseParser .extract_sql (query_text )
131- st .markdown ("**생성된 SQL 쿼리:**" )
132- st .code (sql , language = "sql" )
133- except ValueError :
134- st .warning ("SQL 블록을 추출할 수 없습니다." )
135- st .text (query_text )
136-
137- interpretation = LLMResponseParser .extract_interpretation (query_text )
138- if interpretation :
139- st .markdown ("**결과 해석:**" )
140- st .code (interpretation )
119+ # query_text가 문자열인지 확인
120+ if isinstance (query_text , str ):
121+ try :
122+ sql = LLMResponseParser .extract_sql (query_text )
123+ st .markdown ("**생성된 SQL 쿼리:**" )
124+ st .code (sql , language = "sql" )
125+ except ValueError :
126+ st .warning ("SQL 블록을 추출할 수 없습니다." )
127+ st .text (query_text )
128+
129+ interpretation = LLMResponseParser .extract_interpretation (query_text )
130+ if interpretation :
131+ st .markdown ("**결과 해석:**" )
132+ st .code (interpretation )
133+ else :
134+ st .warning ("쿼리 텍스트가 문자열이 아닙니다." )
135+ st .text (str (query_text ))
141136
142137 if should_show ("show_result_description" ):
143138 st .markdown ("---" )
144139 st .markdown ("**결과 설명:**" )
145140 result_message = res ["messages" ][- 1 ].content
146141
147- try :
148- sql = LLMResponseParser .extract_sql (result_message )
149- st .code (sql , language = "sql" )
150- except ValueError :
151- st .warning ("SQL 블록을 추출할 수 없습니다." )
152- st .text (result_message )
153-
154- interpretation = LLMResponseParser .extract_interpretation (result_message )
155- if interpretation :
156- st .code (interpretation , language = "plaintext" )
142+ if isinstance (result_message , str ):
143+ try :
144+ sql = LLMResponseParser .extract_sql (result_message )
145+ st .code (sql , language = "sql" )
146+ except ValueError :
147+ st .warning ("SQL 블록을 추출할 수 없습니다." )
148+ st .text (result_message )
149+
150+ interpretation = LLMResponseParser .extract_interpretation (result_message )
151+ if interpretation :
152+ st .code (interpretation , language = "plaintext" )
153+ else :
154+ st .warning ("결과 메시지가 문자열이 아닙니다." )
155+ st .text (str (result_message ))
157156
158157 if should_show ("show_question_reinterpreted_by_ai" ):
159158 st .markdown ("---" )
@@ -173,26 +172,41 @@ def should_show(_key: str) -> bool:
173172 if isinstance (res ["generated_query" ], AIMessage )
174173 else str (res ["generated_query" ])
175174 )
176- sql = LLMResponseParser .extract_sql (sql_raw )
177- df = database .run_sql (sql )
178- st .dataframe (df .head (10 ) if len (df ) > 10 else df )
175+ if isinstance (sql_raw , str ):
176+ sql = LLMResponseParser .extract_sql (sql_raw )
177+ df = database .run_sql (sql )
178+ st .dataframe (df .head (10 ) if len (df ) > 10 else df )
179+ else :
180+ st .error ("SQL 원본이 문자열이 아닙니다." )
179181 except Exception as e :
180182 st .error (f"쿼리 실행 중 오류 발생: { e } " )
181183
182184 if should_show ("show_chart" ):
183185 st .markdown ("---" )
184- df = database .run_sql (sql )
185- st .markdown ("**쿼리 결과 시각화:**" )
186- display_code = DisplayChart (
187- question = res ["refined_input" ].content ,
188- sql = sql ,
189- df_metadata = f"Running df.dtypes gives:\n { df .dtypes } " ,
190- )
191- # plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다
192- fig = display_code .get_plotly_figure (
193- plotly_code = display_code .generate_plotly_code (), df = df
194- )
195- st .plotly_chart (fig )
186+ try :
187+ sql_raw = (
188+ res ["generated_query" ].content
189+ if isinstance (res ["generated_query" ], AIMessage )
190+ else str (res ["generated_query" ])
191+ )
192+ if isinstance (sql_raw , str ):
193+ sql = LLMResponseParser .extract_sql (sql_raw )
194+ df = database .run_sql (sql )
195+ st .markdown ("**쿼리 결과 시각화:**" )
196+ display_code = DisplayChart (
197+ question = res ["refined_input" ].content ,
198+ sql = sql ,
199+ df_metadata = f"Running df.dtypes gives:\n { df .dtypes } " ,
200+ )
201+ # plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다
202+ fig = display_code .get_plotly_figure (
203+ plotly_code = display_code .generate_plotly_code (), df = df
204+ )
205+ st .plotly_chart (fig )
206+ else :
207+ st .error ("SQL 원본이 문자열이 아닙니다." )
208+ except Exception as e :
209+ st .error (f"차트 생성 중 오류 발생: { e } " )
196210
197211
198212db = ConnectDB ()
0 commit comments