Skip to content

Commit d7e60e6

Browse files
committed
refactor: 쿼리 실행을 위한 공용 모듈 추가 및 기존 코드 리팩토링
- llm_utils/query_executor.py 파일 추가: CLI와 Streamlit에서 공통으로 사용할 수 있는 쿼리 실행 함수 구현 - query_command 함수에서 그래프 실행 로직을 execute_query 함수로 변경 - lang2sql.py에서 그래프 실행 로직을 execute_query 함수로 통합하여 코드 간소화
1 parent 52a9871 commit d7e60e6

File tree

3 files changed

+214
-102
lines changed

3 files changed

+214
-102
lines changed

cli/__init__.py

Lines changed: 22 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -217,45 +217,32 @@ def query_command(
217217
"""
218218

219219
try:
220-
if use_enriched_graph:
221-
from llm_utils.graph_utils.enriched_graph import builder
222-
else:
223-
from llm_utils.graph_utils.basic_graph import builder
224-
from llm_utils.llm_response_parser import LLMResponseParser
225-
from langchain_core.messages import HumanMessage
226-
227-
logger.info("Processing query: %s", question)
228-
logger.info("Using %s graph", "enriched" if use_enriched_graph else "basic")
229-
230-
# 그래프 컴파일 및 실행
231-
graph = builder.compile()
232-
res = graph.invoke(
233-
input={
234-
"messages": [HumanMessage(content=question)],
235-
"user_database_env": database_env,
236-
"best_practice_query": "",
237-
"retriever_name": retriever_name,
238-
"top_n": top_n,
239-
"device": device,
240-
}
220+
from llm_utils.query_executor import execute_query, extract_sql_from_result
221+
222+
# 공용 함수를 사용하여 쿼리 실행
223+
res = execute_query(
224+
query=question,
225+
database_env=database_env,
226+
retriever_name=retriever_name,
227+
top_n=top_n,
228+
device=device,
229+
use_enriched_graph=use_enriched_graph,
241230
)
242231

243232
# SQL 추출 및 출력
244-
generated_query = res.get("generated_query")
245-
if generated_query:
246-
query_text = (
247-
generated_query.content
248-
if hasattr(generated_query, "content")
249-
else str(generated_query)
250-
)
251-
try:
252-
sql = LLMResponseParser.extract_sql(query_text)
253-
print(sql)
254-
except ValueError:
255-
logger.error("SQL을 추출할 수 없습니다.")
256-
print(query_text)
233+
sql = extract_sql_from_result(res)
234+
if sql:
235+
print(sql)
257236
else:
258-
logger.error("생성된 쿼리가 없습니다.")
237+
# SQL 추출 실패 시 원본 쿼리 텍스트 출력
238+
generated_query = res.get("generated_query")
239+
if generated_query:
240+
query_text = (
241+
generated_query.content
242+
if hasattr(generated_query, "content")
243+
else str(generated_query)
244+
)
245+
print(query_text)
259246

260247
except Exception as e:
261248
logger.error("쿼리 처리 중 오류 발생: %s", e)

interface/lang2sql.py

Lines changed: 81 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77

88
import streamlit as st
99
from langchain.chains.sql_database.prompt import SQL_PROMPTS
10-
from langchain_core.messages import AIMessage, HumanMessage
10+
from langchain_core.messages import AIMessage
1111

1212
from llm_utils.connect_db import ConnectDB
1313
from llm_utils.display_chart import DisplayChart
14-
from llm_utils.graph_utils.enriched_graph import builder as enriched_builder
15-
from llm_utils.graph_utils.basic_graph import builder
14+
from llm_utils.query_executor import execute_query as execute_query_common
1615
from llm_utils.llm_response_parser import LLMResponseParser
1716
from llm_utils.token_utils import TokenUtils
17+
from llm_utils.graph_utils.enriched_graph import builder as enriched_builder
18+
from llm_utils.graph_utils.basic_graph import builder
1819

1920
TITLE = "Lang2SQL"
2021
DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
@@ -40,9 +41,8 @@ def execute_query(
4041
"""
4142
자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 Lang2SQL 그래프 인터페이스 함수입니다.
4243
43-
이 함수는 Lang2SQL 파이프라인(graph)을 세션 상태에서 가져오거나 새로 컴파일한 뒤,
44-
사용자의 자연어 질문을 SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다.
45-
내부적으로 LangChain의 `graph.invoke` 메서드를 호출합니다.
44+
이 함수는 공용 execute_query 함수를 호출하여 Lang2SQL 파이프라인을 실행합니다.
45+
Streamlit 세션 상태를 활용하여 그래프를 재사용합니다.
4646
4747
Args:
4848
query (str): 사용자가 입력한 자연어 기반 질문.
@@ -59,27 +59,16 @@ def execute_query(
5959
- "searched_tables": 참조된 테이블 목록 등 추가 정보
6060
"""
6161

62-
graph = st.session_state.get("graph")
63-
if graph is None:
64-
graph_builder = (
65-
enriched_builder if st.session_state.get("use_enriched") else builder
66-
)
67-
graph = graph_builder.compile()
68-
st.session_state["graph"] = graph
69-
70-
res = graph.invoke(
71-
input={
72-
"messages": [HumanMessage(content=query)],
73-
"user_database_env": database_env,
74-
"best_practice_query": "",
75-
"retriever_name": retriever_name,
76-
"top_n": top_n,
77-
"device": device,
78-
}
62+
return execute_query_common(
63+
query=query,
64+
database_env=database_env,
65+
retriever_name=retriever_name,
66+
top_n=top_n,
67+
device=device,
68+
use_enriched_graph=st.session_state.get("use_enriched", False),
69+
session_state=st.session_state,
7970
)
8071

81-
return res
82-
8372

8473
def display_result(
8574
*,
@@ -120,40 +109,50 @@ def should_show(_key: str) -> bool:
120109
if should_show("show_sql"):
121110
st.markdown("---")
122111
generated_query = res.get("generated_query")
123-
query_text = (
124-
generated_query.content
125-
if isinstance(generated_query, AIMessage)
126-
else str(generated_query)
127-
)
112+
if generated_query:
113+
query_text = (
114+
generated_query.content
115+
if isinstance(generated_query, AIMessage)
116+
else str(generated_query)
117+
)
128118

129-
try:
130-
sql = LLMResponseParser.extract_sql(query_text)
131-
st.markdown("**생성된 SQL 쿼리:**")
132-
st.code(sql, language="sql")
133-
except ValueError:
134-
st.warning("SQL 블록을 추출할 수 없습니다.")
135-
st.text(query_text)
136-
137-
interpretation = LLMResponseParser.extract_interpretation(query_text)
138-
if interpretation:
139-
st.markdown("**결과 해석:**")
140-
st.code(interpretation)
119+
# query_text가 문자열인지 확인
120+
if isinstance(query_text, str):
121+
try:
122+
sql = LLMResponseParser.extract_sql(query_text)
123+
st.markdown("**생성된 SQL 쿼리:**")
124+
st.code(sql, language="sql")
125+
except ValueError:
126+
st.warning("SQL 블록을 추출할 수 없습니다.")
127+
st.text(query_text)
128+
129+
interpretation = LLMResponseParser.extract_interpretation(query_text)
130+
if interpretation:
131+
st.markdown("**결과 해석:**")
132+
st.code(interpretation)
133+
else:
134+
st.warning("쿼리 텍스트가 문자열이 아닙니다.")
135+
st.text(str(query_text))
141136

142137
if should_show("show_result_description"):
143138
st.markdown("---")
144139
st.markdown("**결과 설명:**")
145140
result_message = res["messages"][-1].content
146141

147-
try:
148-
sql = LLMResponseParser.extract_sql(result_message)
149-
st.code(sql, language="sql")
150-
except ValueError:
151-
st.warning("SQL 블록을 추출할 수 없습니다.")
152-
st.text(result_message)
153-
154-
interpretation = LLMResponseParser.extract_interpretation(result_message)
155-
if interpretation:
156-
st.code(interpretation, language="plaintext")
142+
if isinstance(result_message, str):
143+
try:
144+
sql = LLMResponseParser.extract_sql(result_message)
145+
st.code(sql, language="sql")
146+
except ValueError:
147+
st.warning("SQL 블록을 추출할 수 없습니다.")
148+
st.text(result_message)
149+
150+
interpretation = LLMResponseParser.extract_interpretation(result_message)
151+
if interpretation:
152+
st.code(interpretation, language="plaintext")
153+
else:
154+
st.warning("결과 메시지가 문자열이 아닙니다.")
155+
st.text(str(result_message))
157156

158157
if should_show("show_question_reinterpreted_by_ai"):
159158
st.markdown("---")
@@ -173,26 +172,41 @@ def should_show(_key: str) -> bool:
173172
if isinstance(res["generated_query"], AIMessage)
174173
else str(res["generated_query"])
175174
)
176-
sql = LLMResponseParser.extract_sql(sql_raw)
177-
df = database.run_sql(sql)
178-
st.dataframe(df.head(10) if len(df) > 10 else df)
175+
if isinstance(sql_raw, str):
176+
sql = LLMResponseParser.extract_sql(sql_raw)
177+
df = database.run_sql(sql)
178+
st.dataframe(df.head(10) if len(df) > 10 else df)
179+
else:
180+
st.error("SQL 원본이 문자열이 아닙니다.")
179181
except Exception as e:
180182
st.error(f"쿼리 실행 중 오류 발생: {e}")
181183

182184
if should_show("show_chart"):
183185
st.markdown("---")
184-
df = database.run_sql(sql)
185-
st.markdown("**쿼리 결과 시각화:**")
186-
display_code = DisplayChart(
187-
question=res["refined_input"].content,
188-
sql=sql,
189-
df_metadata=f"Running df.dtypes gives:\n{df.dtypes}",
190-
)
191-
# plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다
192-
fig = display_code.get_plotly_figure(
193-
plotly_code=display_code.generate_plotly_code(), df=df
194-
)
195-
st.plotly_chart(fig)
186+
try:
187+
sql_raw = (
188+
res["generated_query"].content
189+
if isinstance(res["generated_query"], AIMessage)
190+
else str(res["generated_query"])
191+
)
192+
if isinstance(sql_raw, str):
193+
sql = LLMResponseParser.extract_sql(sql_raw)
194+
df = database.run_sql(sql)
195+
st.markdown("**쿼리 결과 시각화:**")
196+
display_code = DisplayChart(
197+
question=res["refined_input"].content,
198+
sql=sql,
199+
df_metadata=f"Running df.dtypes gives:\n{df.dtypes}",
200+
)
201+
# plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다
202+
fig = display_code.get_plotly_figure(
203+
plotly_code=display_code.generate_plotly_code(), df=df
204+
)
205+
st.plotly_chart(fig)
206+
else:
207+
st.error("SQL 원본이 문자열이 아닙니다.")
208+
except Exception as e:
209+
st.error(f"차트 생성 중 오류 발생: {e}")
196210

197211

198212
db = ConnectDB()

llm_utils/query_executor.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""
2+
Lang2SQL 쿼리 실행을 위한 공용 모듈입니다.
3+
4+
이 모듈은 CLI와 Streamlit 인터페이스에서 공통으로 사용할 수 있는
5+
쿼리 실행 함수를 제공합니다.
6+
"""
7+
8+
import logging
9+
from typing import Dict, Any, Optional, Union
10+
11+
from langchain_core.messages import HumanMessage
12+
13+
from llm_utils.graph_utils.enriched_graph import builder as enriched_builder
14+
from llm_utils.graph_utils.basic_graph import builder as basic_builder
15+
from llm_utils.llm_response_parser import LLMResponseParser
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
def execute_query(
21+
*,
22+
query: str,
23+
database_env: str,
24+
retriever_name: str = "기본",
25+
top_n: int = 5,
26+
device: str = "cpu",
27+
use_enriched_graph: bool = False,
28+
session_state: Optional[Union[Dict[str, Any], Any]] = None,
29+
) -> Dict[str, Any]:
30+
"""
31+
자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 공용 함수입니다.
32+
33+
이 함수는 Lang2SQL 파이프라인(graph)을 사용하여 사용자의 자연어 질문을
34+
SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다.
35+
CLI와 Streamlit 인터페이스에서 공통으로 사용할 수 있습니다.
36+
37+
Args:
38+
query (str): 사용자가 입력한 자연어 기반 질문.
39+
database_env (str): 사용할 데이터베이스 환경 이름 또는 키 (예: "dev", "prod").
40+
retriever_name (str, optional): 테이블 검색기 이름. 기본값은 "기본".
41+
top_n (int, optional): 검색된 상위 테이블 수 제한. 기본값은 5.
42+
device (str, optional): LLM 실행에 사용할 디바이스 ("cpu" 또는 "cuda"). 기본값은 "cpu".
43+
use_enriched_graph (bool, optional): 확장된 그래프 사용 여부. 기본값은 False.
44+
session_state (Optional[Union[Dict[str, Any], Any]], optional): Streamlit 세션 상태 (Streamlit에서만 사용).
45+
46+
Returns:
47+
Dict[str, Any]: 다음 정보를 포함한 Lang2SQL 실행 결과 딕셔너리:
48+
- "generated_query": 생성된 SQL 쿼리 (`AIMessage`)
49+
- "messages": 전체 LLM 응답 메시지 목록
50+
- "refined_input": AI가 재구성한 입력 질문
51+
- "searched_tables": 참조된 테이블 목록 등 추가 정보
52+
"""
53+
54+
logger.info("Processing query: %s", query)
55+
logger.info("Using %s graph", "enriched" if use_enriched_graph else "basic")
56+
57+
# 그래프 선택 및 컴파일
58+
if session_state is not None:
59+
# Streamlit 환경: 세션 상태에서 그래프 재사용
60+
graph = session_state.get("graph")
61+
if graph is None:
62+
graph_builder = enriched_builder if use_enriched_graph else basic_builder
63+
graph = graph_builder.compile()
64+
session_state["graph"] = graph
65+
else:
66+
# CLI 환경: 매번 새로운 그래프 컴파일
67+
graph_builder = enriched_builder if use_enriched_graph else basic_builder
68+
graph = graph_builder.compile()
69+
70+
# 그래프 실행
71+
res = graph.invoke(
72+
input={
73+
"messages": [HumanMessage(content=query)],
74+
"user_database_env": database_env,
75+
"best_practice_query": "",
76+
"retriever_name": retriever_name,
77+
"top_n": top_n,
78+
"device": device,
79+
}
80+
)
81+
82+
return res
83+
84+
85+
def extract_sql_from_result(res: Dict[str, Any]) -> Optional[str]:
86+
"""
87+
Lang2SQL 실행 결과에서 SQL 쿼리를 추출합니다.
88+
89+
Args:
90+
res (Dict[str, Any]): execute_query 함수의 반환 결과
91+
92+
Returns:
93+
Optional[str]: 추출된 SQL 쿼리 문자열. 추출 실패 시 None
94+
"""
95+
generated_query = res.get("generated_query")
96+
if not generated_query:
97+
logger.error("생성된 쿼리가 없습니다.")
98+
return None
99+
100+
query_text = (
101+
generated_query.content
102+
if hasattr(generated_query, "content")
103+
else str(generated_query)
104+
)
105+
106+
try:
107+
sql = LLMResponseParser.extract_sql(query_text)
108+
return sql
109+
except ValueError:
110+
logger.error("SQL을 추출할 수 없습니다.")
111+
return None

0 commit comments

Comments
 (0)