Skip to content

Commit 52a9871

Browse files
committed
feat: cli로 lang2sql 을 사용할수있도록 수정
- main.py 파일 삭제 - basic_graph 및 enriched_graph 그래프 구성 분리 - query cli 추가 - cli output foramt 추가
1 parent 717bfb1 commit 52a9871

File tree

8 files changed

+186
-38
lines changed

8 files changed

+186
-38
lines changed

cli/__init__.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,104 @@ def run_streamlit_cli_command(port: int) -> None:
159159

160160
logger.info("Executing 'run-streamlit' command on port %d...", port)
161161
run_streamlit_command(port)
162+
163+
164+
@cli.command(name="query")
165+
@click.argument("question", type=str)
166+
@click.option(
167+
"--database-env",
168+
default="clickhouse",
169+
help="사용할 데이터베이스 환경 (기본값: clickhouse)",
170+
)
171+
@click.option(
172+
"--retriever-name",
173+
default="기본",
174+
help="테이블 검색기 이름 (기본값: 기본)",
175+
)
176+
@click.option(
177+
"--top-n",
178+
type=int,
179+
default=5,
180+
help="검색된 상위 테이블 수 제한 (기본값: 5)",
181+
)
182+
@click.option(
183+
"--device",
184+
default="cpu",
185+
help="LLM 실행에 사용할 디바이스 (기본값: cpu)",
186+
)
187+
@click.option(
188+
"--use-enriched-graph",
189+
is_flag=True,
190+
help="확장된 그래프(프로파일 추출 + 컨텍스트 보강) 사용 여부",
191+
)
192+
def query_command(
193+
question: str,
194+
database_env: str,
195+
retriever_name: str,
196+
top_n: int,
197+
device: str,
198+
use_enriched_graph: bool,
199+
) -> None:
200+
"""
201+
자연어 질문을 SQL 쿼리로 변환하여 출력하는 명령어입니다.
202+
203+
이 명령은 사용자가 입력한 자연어 질문을 받아서 SQL 쿼리로 변환하고,
204+
생성된 SQL 쿼리만을 표준 출력으로 출력합니다.
205+
206+
매개변수:
207+
question (str): SQL로 변환할 자연어 질문
208+
database_env (str): 사용할 데이터베이스 환경
209+
retriever_name (str): 테이블 검색기 이름
210+
top_n (int): 검색된 상위 테이블 수 제한
211+
device (str): LLM 실행에 사용할 디바이스
212+
use_enriched_graph (bool): 확장된 그래프 사용 여부
213+
214+
예시:
215+
lang2sql query "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
216+
lang2sql query "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" --use-enriched-graph
217+
"""
218+
219+
try:
220+
if use_enriched_graph:
221+
from llm_utils.graph_utils.enriched_graph import builder
222+
else:
223+
from llm_utils.graph_utils.basic_graph import builder
224+
from llm_utils.llm_response_parser import LLMResponseParser
225+
from langchain_core.messages import HumanMessage
226+
227+
logger.info("Processing query: %s", question)
228+
logger.info("Using %s graph", "enriched" if use_enriched_graph else "basic")
229+
230+
# 그래프 컴파일 및 실행
231+
graph = builder.compile()
232+
res = graph.invoke(
233+
input={
234+
"messages": [HumanMessage(content=question)],
235+
"user_database_env": database_env,
236+
"best_practice_query": "",
237+
"retriever_name": retriever_name,
238+
"top_n": top_n,
239+
"device": device,
240+
}
241+
)
242+
243+
# SQL 추출 및 출력
244+
generated_query = res.get("generated_query")
245+
if generated_query:
246+
query_text = (
247+
generated_query.content
248+
if hasattr(generated_query, "content")
249+
else str(generated_query)
250+
)
251+
try:
252+
sql = LLMResponseParser.extract_sql(query_text)
253+
print(sql)
254+
except ValueError:
255+
logger.error("SQL을 추출할 수 없습니다.")
256+
print(query_text)
257+
else:
258+
logger.error("생성된 쿼리가 없습니다.")
259+
260+
except Exception as e:
261+
logger.error("쿼리 처리 중 오류 발생: %s", e)
262+
raise

evaluation/gen_answer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from tqdm import tqdm
77
import uuid
88

9-
from llm_utils.graph import builder
9+
from llm_utils.graph_utils.basic_graph import builder
1010

1111

1212
def get_eval_result(

interface/lang2sql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
from llm_utils.connect_db import ConnectDB
1313
from llm_utils.display_chart import DisplayChart
14-
from llm_utils.enriched_graph import builder as enriched_builder
15-
from llm_utils.graph import builder
14+
from llm_utils.graph_utils.enriched_graph import builder as enriched_builder
15+
from llm_utils.graph_utils.basic_graph import builder
1616
from llm_utils.llm_response_parser import LLMResponseParser
1717
from llm_utils.token_utils import TokenUtils
1818

llm_utils/graph_utils/__init__.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""
2+
그래프 관련 유틸리티 모듈입니다.
3+
4+
이 패키지는 Lang2SQL의 워크플로우 그래프 구성과 관련된 모듈들을 포함합니다.
5+
"""
6+
7+
from .base import (
8+
QueryMakerState,
9+
GET_TABLE_INFO,
10+
QUERY_REFINER,
11+
QUERY_MAKER,
12+
PROFILE_EXTRACTION,
13+
CONTEXT_ENRICHMENT,
14+
get_table_info_node,
15+
query_refiner_node,
16+
query_maker_node,
17+
profile_extraction_node,
18+
query_refiner_with_profile_node,
19+
context_enrichment_node,
20+
query_maker_node_with_db_guide,
21+
)
22+
23+
from .basic_graph import builder as basic_builder
24+
from .enriched_graph import builder as enriched_builder
25+
26+
__all__ = [
27+
# 상태 및 노드 식별자
28+
"QueryMakerState",
29+
"GET_TABLE_INFO",
30+
"QUERY_REFINER",
31+
"QUERY_MAKER",
32+
"PROFILE_EXTRACTION",
33+
"CONTEXT_ENRICHMENT",
34+
# 노드 함수들
35+
"get_table_info_node",
36+
"query_refiner_node",
37+
"query_maker_node",
38+
"profile_extraction_node",
39+
"query_refiner_with_profile_node",
40+
"context_enrichment_node",
41+
"query_maker_node_with_db_guide",
42+
# 그래프 빌더들
43+
"basic_builder",
44+
"enriched_builder",
45+
]
Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from langgraph.graph.message import add_messages
77
from langchain.chains.sql_database.prompt import SQL_PROMPTS
88
from pydantic import BaseModel, Field
9-
from .llm_factory import get_llm
9+
from llm_utils.llm_factory import get_llm
1010

1111
from llm_utils.chains import (
1212
query_refiner_chain,
@@ -119,7 +119,7 @@ def context_enrichment_node(state: QueryMakerState):
119119
주요 작업:
120120
- 주어진 질문의 메타데이터 (`question_profile` 및 `searched_tables`)를 활용하여, 질문을 수정하거나 추가 정보를 삽입합니다.
121121
- 질문이 시계열 분석 또는 집계 함수 관련인 경우, 이를 명시적으로 강조합니다 (예: "지난 30일 동안").
122-
- 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: ‘미국’USA).
122+
- 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: '미국''USA').
123123
- 보강된 질문을 출력합니다.
124124
125125
Args:
@@ -207,23 +207,3 @@ def query_maker_node_with_db_guide(state: QueryMakerState):
207207
state["generated_query"] = res.sql
208208
state["messages"].append(res.explanation)
209209
return state
210-
211-
212-
# StateGraph 생성 및 구성
213-
builder = StateGraph(QueryMakerState)
214-
builder.set_entry_point(GET_TABLE_INFO)
215-
216-
# 노드 추가
217-
builder.add_node(GET_TABLE_INFO, get_table_info_node)
218-
builder.add_node(QUERY_REFINER, query_refiner_node)
219-
builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide
220-
# builder.add_node(
221-
# QUERY_MAKER, query_maker_node_with_db_guide
222-
# ) # query_maker_node_with_db_guide
223-
224-
# 기본 엣지 설정
225-
builder.add_edge(GET_TABLE_INFO, QUERY_REFINER)
226-
builder.add_edge(QUERY_REFINER, QUERY_MAKER)
227-
228-
# QUERY_MAKER 노드 후 종료
229-
builder.add_edge(QUERY_MAKER, END)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import json
2+
3+
from langgraph.graph import StateGraph, END
4+
from llm_utils.graph_utils.base import (
5+
QueryMakerState,
6+
GET_TABLE_INFO,
7+
QUERY_REFINER,
8+
QUERY_MAKER,
9+
get_table_info_node,
10+
query_refiner_node,
11+
query_maker_node,
12+
)
13+
14+
"""
15+
기본 워크플로우를 위한 StateGraph 구성입니다.
16+
GET_TABLE_INFO -> QUERY_REFINER -> QUERY_MAKER 순서로 실행됩니다.
17+
"""
18+
19+
# StateGraph 생성 및 구성
20+
builder = StateGraph(QueryMakerState)
21+
builder.set_entry_point(GET_TABLE_INFO)
22+
23+
# 노드 추가
24+
builder.add_node(GET_TABLE_INFO, get_table_info_node)
25+
builder.add_node(QUERY_REFINER, query_refiner_node)
26+
builder.add_node(QUERY_MAKER, query_maker_node)
27+
28+
# 기본 엣지 설정
29+
builder.add_edge(GET_TABLE_INFO, QUERY_REFINER)
30+
builder.add_edge(QUERY_REFINER, QUERY_MAKER)
31+
32+
# QUERY_MAKER 노드 후 종료
33+
builder.add_edge(QUERY_MAKER, END)
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22

33
from langgraph.graph import StateGraph, END
4-
from llm_utils.graph import (
4+
from llm_utils.graph_utils.base import (
55
QueryMakerState,
66
GET_TABLE_INFO,
77
PROFILE_EXTRACTION,
@@ -26,8 +26,8 @@
2626

2727
# 노드 추가
2828
builder.add_node(GET_TABLE_INFO, get_table_info_node)
29-
builder.add_node(QUERY_REFINER, query_refiner_with_profile_node)
3029
builder.add_node(PROFILE_EXTRACTION, profile_extraction_node)
30+
builder.add_node(QUERY_REFINER, query_refiner_with_profile_node)
3131
builder.add_node(CONTEXT_ENRICHMENT, context_enrichment_node)
3232
builder.add_node(QUERY_MAKER, query_maker_node)
3333

main.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)