Skip to content

Commit 6a5451c

Browse files
authored
Merge pull request #125 from #123
123 cli output format 지원
2 parents 3a94fb0 + d7e60e6 commit 6a5451c

File tree

9 files changed

+363
-103
lines changed

9 files changed

+363
-103
lines changed

cli/__init__.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,91 @@ def run_streamlit_cli_command(port: int) -> None:
159159

160160
logger.info("Executing 'run-streamlit' command on port %d...", port)
161161
run_streamlit_command(port)
162+
163+
164+
@cli.command(name="query")
165+
@click.argument("question", type=str)
166+
@click.option(
167+
"--database-env",
168+
default="clickhouse",
169+
help="사용할 데이터베이스 환경 (기본값: clickhouse)",
170+
)
171+
@click.option(
172+
"--retriever-name",
173+
default="기본",
174+
help="테이블 검색기 이름 (기본값: 기본)",
175+
)
176+
@click.option(
177+
"--top-n",
178+
type=int,
179+
default=5,
180+
help="검색된 상위 테이블 수 제한 (기본값: 5)",
181+
)
182+
@click.option(
183+
"--device",
184+
default="cpu",
185+
help="LLM 실행에 사용할 디바이스 (기본값: cpu)",
186+
)
187+
@click.option(
188+
"--use-enriched-graph",
189+
is_flag=True,
190+
help="확장된 그래프(프로파일 추출 + 컨텍스트 보강) 사용 여부",
191+
)
192+
def query_command(
193+
question: str,
194+
database_env: str,
195+
retriever_name: str,
196+
top_n: int,
197+
device: str,
198+
use_enriched_graph: bool,
199+
) -> None:
200+
"""
201+
자연어 질문을 SQL 쿼리로 변환하여 출력하는 명령어입니다.
202+
203+
이 명령은 사용자가 입력한 자연어 질문을 받아서 SQL 쿼리로 변환하고,
204+
생성된 SQL 쿼리만을 표준 출력으로 출력합니다.
205+
206+
매개변수:
207+
question (str): SQL로 변환할 자연어 질문
208+
database_env (str): 사용할 데이터베이스 환경
209+
retriever_name (str): 테이블 검색기 이름
210+
top_n (int): 검색된 상위 테이블 수 제한
211+
device (str): LLM 실행에 사용할 디바이스
212+
use_enriched_graph (bool): 확장된 그래프 사용 여부
213+
214+
예시:
215+
lang2sql query "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
216+
lang2sql query "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" --use-enriched-graph
217+
"""
218+
219+
try:
220+
from llm_utils.query_executor import execute_query, extract_sql_from_result
221+
222+
# 공용 함수를 사용하여 쿼리 실행
223+
res = execute_query(
224+
query=question,
225+
database_env=database_env,
226+
retriever_name=retriever_name,
227+
top_n=top_n,
228+
device=device,
229+
use_enriched_graph=use_enriched_graph,
230+
)
231+
232+
# SQL 추출 및 출력
233+
sql = extract_sql_from_result(res)
234+
if sql:
235+
print(sql)
236+
else:
237+
# SQL 추출 실패 시 원본 쿼리 텍스트 출력
238+
generated_query = res.get("generated_query")
239+
if generated_query:
240+
query_text = (
241+
generated_query.content
242+
if hasattr(generated_query, "content")
243+
else str(generated_query)
244+
)
245+
print(query_text)
246+
247+
except Exception as e:
248+
logger.error("쿼리 처리 중 오류 발생: %s", e)
249+
raise

evaluation/gen_answer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from tqdm import tqdm
77
import uuid
88

9-
from llm_utils.graph import builder
9+
from llm_utils.graph_utils.basic_graph import builder
1010

1111

1212
def get_eval_result(

interface/lang2sql.py

Lines changed: 81 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,17 @@
99

1010
import streamlit as st
1111
from langchain.chains.sql_database.prompt import SQL_PROMPTS
12-
from langchain_core.messages import AIMessage, HumanMessage
12+
from langchain_core.messages import AIMessage
1313

1414
from db_utils import get_db_connector
1515
from db_utils.base_connector import BaseConnector
1616
from llm_utils.connect_db import ConnectDB
1717
from llm_utils.display_chart import DisplayChart
18-
from llm_utils.enriched_graph import builder as enriched_builder
19-
from llm_utils.graph import builder
18+
from llm_utils.query_executor import execute_query as execute_query_common
2019
from llm_utils.llm_response_parser import LLMResponseParser
2120
from llm_utils.token_utils import TokenUtils
21+
from llm_utils.graph_utils.enriched_graph import builder as enriched_builder
22+
from llm_utils.graph_utils.basic_graph import builder
2223

2324

2425
TITLE = "Lang2SQL"
@@ -45,9 +46,8 @@ def execute_query(
4546
"""
4647
자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 Lang2SQL 그래프 인터페이스 함수입니다.
4748
48-
이 함수는 Lang2SQL 파이프라인(graph)을 세션 상태에서 가져오거나 새로 컴파일한 뒤,
49-
사용자의 자연어 질문을 SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다.
50-
내부적으로 LangChain의 `graph.invoke` 메서드를 호출합니다.
49+
이 함수는 공용 execute_query 함수를 호출하여 Lang2SQL 파이프라인을 실행합니다.
50+
Streamlit 세션 상태를 활용하여 그래프를 재사용합니다.
5151
5252
Args:
5353
query (str): 사용자가 입력한 자연어 기반 질문.
@@ -64,27 +64,16 @@ def execute_query(
6464
- "searched_tables": 참조된 테이블 목록 등 추가 정보
6565
"""
6666

67-
graph = st.session_state.get("graph")
68-
if graph is None:
69-
graph_builder = (
70-
enriched_builder if st.session_state.get("use_enriched") else builder
71-
)
72-
graph = graph_builder.compile()
73-
st.session_state["graph"] = graph
74-
75-
res = graph.invoke(
76-
input={
77-
"messages": [HumanMessage(content=query)],
78-
"user_database_env": database_env,
79-
"best_practice_query": "",
80-
"retriever_name": retriever_name,
81-
"top_n": top_n,
82-
"device": device,
83-
}
67+
return execute_query_common(
68+
query=query,
69+
database_env=database_env,
70+
retriever_name=retriever_name,
71+
top_n=top_n,
72+
device=device,
73+
use_enriched_graph=st.session_state.get("use_enriched", False),
74+
session_state=st.session_state,
8475
)
8576

86-
return res
87-
8877

8978
def display_result(
9079
*,
@@ -125,40 +114,50 @@ def should_show(_key: str) -> bool:
125114
if should_show("show_sql"):
126115
st.markdown("---")
127116
generated_query = res.get("generated_query")
128-
query_text = (
129-
generated_query.content
130-
if isinstance(generated_query, AIMessage)
131-
else str(generated_query)
132-
)
117+
if generated_query:
118+
query_text = (
119+
generated_query.content
120+
if isinstance(generated_query, AIMessage)
121+
else str(generated_query)
122+
)
133123

134-
try:
135-
sql = LLMResponseParser.extract_sql(query_text)
136-
st.markdown("**생성된 SQL 쿼리:**")
137-
st.code(sql, language="sql")
138-
except ValueError:
139-
st.warning("SQL 블록을 추출할 수 없습니다.")
140-
st.text(query_text)
141-
142-
interpretation = LLMResponseParser.extract_interpretation(query_text)
143-
if interpretation:
144-
st.markdown("**결과 해석:**")
145-
st.code(interpretation)
124+
# query_text가 문자열인지 확인
125+
if isinstance(query_text, str):
126+
try:
127+
sql = LLMResponseParser.extract_sql(query_text)
128+
st.markdown("**생성된 SQL 쿼리:**")
129+
st.code(sql, language="sql")
130+
except ValueError:
131+
st.warning("SQL 블록을 추출할 수 없습니다.")
132+
st.text(query_text)
133+
134+
interpretation = LLMResponseParser.extract_interpretation(query_text)
135+
if interpretation:
136+
st.markdown("**결과 해석:**")
137+
st.code(interpretation)
138+
else:
139+
st.warning("쿼리 텍스트가 문자열이 아닙니다.")
140+
st.text(str(query_text))
146141

147142
if should_show("show_result_description"):
148143
st.markdown("---")
149144
st.markdown("**결과 설명:**")
150145
result_message = res["messages"][-1].content
151146

152-
try:
153-
sql = LLMResponseParser.extract_sql(result_message)
154-
st.code(sql, language="sql")
155-
except ValueError:
156-
st.warning("SQL 블록을 추출할 수 없습니다.")
157-
st.text(result_message)
158-
159-
interpretation = LLMResponseParser.extract_interpretation(result_message)
160-
if interpretation:
161-
st.code(interpretation, language="plaintext")
147+
if isinstance(result_message, str):
148+
try:
149+
sql = LLMResponseParser.extract_sql(result_message)
150+
st.code(sql, language="sql")
151+
except ValueError:
152+
st.warning("SQL 블록을 추출할 수 없습니다.")
153+
st.text(result_message)
154+
155+
interpretation = LLMResponseParser.extract_interpretation(result_message)
156+
if interpretation:
157+
st.code(interpretation, language="plaintext")
158+
else:
159+
st.warning("결과 메시지가 문자열이 아닙니다.")
160+
st.text(str(result_message))
162161

163162
if should_show("show_question_reinterpreted_by_ai"):
164163
st.markdown("---")
@@ -178,26 +177,41 @@ def should_show(_key: str) -> bool:
178177
if isinstance(res["generated_query"], AIMessage)
179178
else str(res["generated_query"])
180179
)
181-
sql = LLMResponseParser.extract_sql(sql_raw)
182-
df = database.run_sql(sql)
183-
st.dataframe(df.head(10) if len(df) > 10 else df)
180+
if isinstance(sql_raw, str):
181+
sql = LLMResponseParser.extract_sql(sql_raw)
182+
df = database.run_sql(sql)
183+
st.dataframe(df.head(10) if len(df) > 10 else df)
184+
else:
185+
st.error("SQL 원본이 문자열이 아닙니다.")
184186
except Exception as e:
185187
st.error(f"쿼리 실행 중 오류 발생: {e}")
186188

187189
if should_show("show_chart"):
188190
st.markdown("---")
189-
df = database.run_sql(sql)
190-
st.markdown("**쿼리 결과 시각화:**")
191-
display_code = DisplayChart(
192-
question=res["refined_input"].content,
193-
sql=sql,
194-
df_metadata=f"Running df.dtypes gives:\n{df.dtypes}",
195-
)
196-
# plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다
197-
fig = display_code.get_plotly_figure(
198-
plotly_code=display_code.generate_plotly_code(), df=df
199-
)
200-
st.plotly_chart(fig)
191+
try:
192+
sql_raw = (
193+
res["generated_query"].content
194+
if isinstance(res["generated_query"], AIMessage)
195+
else str(res["generated_query"])
196+
)
197+
if isinstance(sql_raw, str):
198+
sql = LLMResponseParser.extract_sql(sql_raw)
199+
df = database.run_sql(sql)
200+
st.markdown("**쿼리 결과 시각화:**")
201+
display_code = DisplayChart(
202+
question=res["refined_input"].content,
203+
sql=sql,
204+
df_metadata=f"Running df.dtypes gives:\n{df.dtypes}",
205+
)
206+
# plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다
207+
fig = display_code.get_plotly_figure(
208+
plotly_code=display_code.generate_plotly_code(), df=df
209+
)
210+
st.plotly_chart(fig)
211+
else:
212+
st.error("SQL 원본이 문자열이 아닙니다.")
213+
except Exception as e:
214+
st.error(f"차트 생성 중 오류 발생: {e}")
201215

202216

203217
db = get_db_connector()

llm_utils/graph_utils/__init__.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""
2+
그래프 관련 유틸리티 모듈입니다.
3+
4+
이 패키지는 Lang2SQL의 워크플로우 그래프 구성과 관련된 모듈들을 포함합니다.
5+
"""
6+
7+
from .base import (
8+
QueryMakerState,
9+
GET_TABLE_INFO,
10+
QUERY_REFINER,
11+
QUERY_MAKER,
12+
PROFILE_EXTRACTION,
13+
CONTEXT_ENRICHMENT,
14+
get_table_info_node,
15+
query_refiner_node,
16+
query_maker_node,
17+
profile_extraction_node,
18+
query_refiner_with_profile_node,
19+
context_enrichment_node,
20+
query_maker_node_with_db_guide,
21+
)
22+
23+
from .basic_graph import builder as basic_builder
24+
from .enriched_graph import builder as enriched_builder
25+
26+
__all__ = [
27+
# 상태 및 노드 식별자
28+
"QueryMakerState",
29+
"GET_TABLE_INFO",
30+
"QUERY_REFINER",
31+
"QUERY_MAKER",
32+
"PROFILE_EXTRACTION",
33+
"CONTEXT_ENRICHMENT",
34+
# 노드 함수들
35+
"get_table_info_node",
36+
"query_refiner_node",
37+
"query_maker_node",
38+
"profile_extraction_node",
39+
"query_refiner_with_profile_node",
40+
"context_enrichment_node",
41+
"query_maker_node_with_db_guide",
42+
# 그래프 빌더들
43+
"basic_builder",
44+
"enriched_builder",
45+
]

llm_utils/graph.py renamed to llm_utils/graph_utils/base.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from langgraph.graph.message import add_messages
77
from langchain.chains.sql_database.prompt import SQL_PROMPTS
88
from pydantic import BaseModel, Field
9-
from .llm_factory import get_llm
9+
from llm_utils.llm_factory import get_llm
1010

1111
from llm_utils.chains import (
1212
query_refiner_chain,
@@ -119,7 +119,7 @@ def context_enrichment_node(state: QueryMakerState):
119119
주요 작업:
120120
- 주어진 질문의 메타데이터 (`question_profile` 및 `searched_tables`)를 활용하여, 질문을 수정하거나 추가 정보를 삽입합니다.
121121
- 질문이 시계열 분석 또는 집계 함수 관련인 경우, 이를 명시적으로 강조합니다 (예: "지난 30일 동안").
122-
- 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: ‘미국’USA).
122+
- 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: '미국''USA').
123123
- 보강된 질문을 출력합니다.
124124
125125
Args:
@@ -207,23 +207,3 @@ def query_maker_node_with_db_guide(state: QueryMakerState):
207207
state["generated_query"] = res.sql
208208
state["messages"].append(res.explanation)
209209
return state
210-
211-
212-
# StateGraph 생성 및 구성
213-
builder = StateGraph(QueryMakerState)
214-
builder.set_entry_point(GET_TABLE_INFO)
215-
216-
# 노드 추가
217-
builder.add_node(GET_TABLE_INFO, get_table_info_node)
218-
builder.add_node(QUERY_REFINER, query_refiner_node)
219-
builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide
220-
# builder.add_node(
221-
# QUERY_MAKER, query_maker_node_with_db_guide
222-
# ) # query_maker_node_with_db_guide
223-
224-
# 기본 엣지 설정
225-
builder.add_edge(GET_TABLE_INFO, QUERY_REFINER)
226-
builder.add_edge(QUERY_REFINER, QUERY_MAKER)
227-
228-
# QUERY_MAKER 노드 후 종료
229-
builder.add_edge(QUERY_MAKER, END)

0 commit comments

Comments
 (0)