99
1010import streamlit as st
1111from langchain .chains .sql_database .prompt import SQL_PROMPTS
12- from langchain_core .messages import AIMessage , HumanMessage
12+ from langchain_core .messages import AIMessage
1313
1414from db_utils import get_db_connector
1515from db_utils .base_connector import BaseConnector
1616from llm_utils .connect_db import ConnectDB
1717from llm_utils .display_chart import DisplayChart
18- from llm_utils .enriched_graph import builder as enriched_builder
19- from llm_utils .graph import builder
18+ from llm_utils .query_executor import execute_query as execute_query_common
2019from llm_utils .llm_response_parser import LLMResponseParser
2120from llm_utils .token_utils import TokenUtils
21+ from llm_utils .graph_utils .enriched_graph import builder as enriched_builder
22+ from llm_utils .graph_utils .basic_graph import builder
2223
2324
2425TITLE = "Lang2SQL"
@@ -45,9 +46,8 @@ def execute_query(
4546 """
4647 자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 Lang2SQL 그래프 인터페이스 함수입니다.
4748
48- 이 함수는 Lang2SQL 파이프라인(graph)을 세션 상태에서 가져오거나 새로 컴파일한 뒤,
49- 사용자의 자연어 질문을 SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다.
50- 내부적으로 LangChain의 `graph.invoke` 메서드를 호출합니다.
49+ 이 함수는 공용 execute_query 함수를 호출하여 Lang2SQL 파이프라인을 실행합니다.
50+ Streamlit 세션 상태를 활용하여 그래프를 재사용합니다.
5151
5252 Args:
5353 query (str): 사용자가 입력한 자연어 기반 질문.
@@ -64,27 +64,16 @@ def execute_query(
6464 - "searched_tables": 참조된 테이블 목록 등 추가 정보
6565 """
6666
67- graph = st .session_state .get ("graph" )
68- if graph is None :
69- graph_builder = (
70- enriched_builder if st .session_state .get ("use_enriched" ) else builder
71- )
72- graph = graph_builder .compile ()
73- st .session_state ["graph" ] = graph
74-
75- res = graph .invoke (
76- input = {
77- "messages" : [HumanMessage (content = query )],
78- "user_database_env" : database_env ,
79- "best_practice_query" : "" ,
80- "retriever_name" : retriever_name ,
81- "top_n" : top_n ,
82- "device" : device ,
83- }
67+ return execute_query_common (
68+ query = query ,
69+ database_env = database_env ,
70+ retriever_name = retriever_name ,
71+ top_n = top_n ,
72+ device = device ,
73+ use_enriched_graph = st .session_state .get ("use_enriched" , False ),
74+ session_state = st .session_state ,
8475 )
8576
86- return res
87-
8877
8978def display_result (
9079 * ,
@@ -125,40 +114,50 @@ def should_show(_key: str) -> bool:
125114 if should_show ("show_sql" ):
126115 st .markdown ("---" )
127116 generated_query = res .get ("generated_query" )
128- query_text = (
129- generated_query .content
130- if isinstance (generated_query , AIMessage )
131- else str (generated_query )
132- )
117+ if generated_query :
118+ query_text = (
119+ generated_query .content
120+ if isinstance (generated_query , AIMessage )
121+ else str (generated_query )
122+ )
133123
134- try :
135- sql = LLMResponseParser .extract_sql (query_text )
136- st .markdown ("**생성된 SQL 쿼리:**" )
137- st .code (sql , language = "sql" )
138- except ValueError :
139- st .warning ("SQL 블록을 추출할 수 없습니다." )
140- st .text (query_text )
141-
142- interpretation = LLMResponseParser .extract_interpretation (query_text )
143- if interpretation :
144- st .markdown ("**결과 해석:**" )
145- st .code (interpretation )
124+ # query_text가 문자열인지 확인
125+ if isinstance (query_text , str ):
126+ try :
127+ sql = LLMResponseParser .extract_sql (query_text )
128+ st .markdown ("**생성된 SQL 쿼리:**" )
129+ st .code (sql , language = "sql" )
130+ except ValueError :
131+ st .warning ("SQL 블록을 추출할 수 없습니다." )
132+ st .text (query_text )
133+
134+ interpretation = LLMResponseParser .extract_interpretation (query_text )
135+ if interpretation :
136+ st .markdown ("**결과 해석:**" )
137+ st .code (interpretation )
138+ else :
139+ st .warning ("쿼리 텍스트가 문자열이 아닙니다." )
140+ st .text (str (query_text ))
146141
147142 if should_show ("show_result_description" ):
148143 st .markdown ("---" )
149144 st .markdown ("**결과 설명:**" )
150145 result_message = res ["messages" ][- 1 ].content
151146
152- try :
153- sql = LLMResponseParser .extract_sql (result_message )
154- st .code (sql , language = "sql" )
155- except ValueError :
156- st .warning ("SQL 블록을 추출할 수 없습니다." )
157- st .text (result_message )
158-
159- interpretation = LLMResponseParser .extract_interpretation (result_message )
160- if interpretation :
161- st .code (interpretation , language = "plaintext" )
147+ if isinstance (result_message , str ):
148+ try :
149+ sql = LLMResponseParser .extract_sql (result_message )
150+ st .code (sql , language = "sql" )
151+ except ValueError :
152+ st .warning ("SQL 블록을 추출할 수 없습니다." )
153+ st .text (result_message )
154+
155+ interpretation = LLMResponseParser .extract_interpretation (result_message )
156+ if interpretation :
157+ st .code (interpretation , language = "plaintext" )
158+ else :
159+ st .warning ("결과 메시지가 문자열이 아닙니다." )
160+ st .text (str (result_message ))
162161
163162 if should_show ("show_question_reinterpreted_by_ai" ):
164163 st .markdown ("---" )
@@ -178,26 +177,41 @@ def should_show(_key: str) -> bool:
178177 if isinstance (res ["generated_query" ], AIMessage )
179178 else str (res ["generated_query" ])
180179 )
181- sql = LLMResponseParser .extract_sql (sql_raw )
182- df = database .run_sql (sql )
183- st .dataframe (df .head (10 ) if len (df ) > 10 else df )
180+ if isinstance (sql_raw , str ):
181+ sql = LLMResponseParser .extract_sql (sql_raw )
182+ df = database .run_sql (sql )
183+ st .dataframe (df .head (10 ) if len (df ) > 10 else df )
184+ else :
185+ st .error ("SQL 원본이 문자열이 아닙니다." )
184186 except Exception as e :
185187 st .error (f"쿼리 실행 중 오류 발생: { e } " )
186188
187189 if should_show ("show_chart" ):
188190 st .markdown ("---" )
189- df = database .run_sql (sql )
190- st .markdown ("**쿼리 결과 시각화:**" )
191- display_code = DisplayChart (
192- question = res ["refined_input" ].content ,
193- sql = sql ,
194- df_metadata = f"Running df.dtypes gives:\n { df .dtypes } " ,
195- )
196- # plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다
197- fig = display_code .get_plotly_figure (
198- plotly_code = display_code .generate_plotly_code (), df = df
199- )
200- st .plotly_chart (fig )
191+ try :
192+ sql_raw = (
193+ res ["generated_query" ].content
194+ if isinstance (res ["generated_query" ], AIMessage )
195+ else str (res ["generated_query" ])
196+ )
197+ if isinstance (sql_raw , str ):
198+ sql = LLMResponseParser .extract_sql (sql_raw )
199+ df = database .run_sql (sql )
200+ st .markdown ("**쿼리 결과 시각화:**" )
201+ display_code = DisplayChart (
202+ question = res ["refined_input" ].content ,
203+ sql = sql ,
204+ df_metadata = f"Running df.dtypes gives:\n { df .dtypes } " ,
205+ )
206+ # plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다
207+ fig = display_code .get_plotly_figure (
208+ plotly_code = display_code .generate_plotly_code (), df = df
209+ )
210+ st .plotly_chart (fig )
211+ else :
212+ st .error ("SQL 원본이 문자열이 아닙니다." )
213+ except Exception as e :
214+ st .error (f"차트 생성 중 오류 발생: { e } " )
201215
202216
203217db = get_db_connector ()
0 commit comments