Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,19 @@ def run_streamlit_cli_command(port: int) -> None:
is_flag=True,
help="확장된 그래프(프로파일 추출 + 컨텍스트 보강) 사용 여부",
)
@click.option(
"--use-simplified-graph",
is_flag=True,
help="단순화된 그래프(QUERY_REFINER 제거) 사용 여부",
)
def query_command(
question: str,
database_env: str,
retriever_name: str,
top_n: int,
device: str,
use_enriched_graph: bool,
use_simplified_graph: bool,
) -> None:
"""
자연어 질문을 SQL 쿼리로 변환하여 출력하는 명령어입니다.
Expand Down Expand Up @@ -227,6 +233,7 @@ def query_command(
top_n=top_n,
device=device,
use_enriched_graph=use_enriched_graph,
use_simplified_graph=use_simplified_graph,
)

# SQL 추출 및 출력
Expand Down
42 changes: 34 additions & 8 deletions interface/lang2sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from llm_utils.token_utils import TokenUtils
from llm_utils.graph_utils.enriched_graph import builder as enriched_builder
from llm_utils.graph_utils.basic_graph import builder
from llm_utils.graph_utils.simplified_graph import builder as simplified_builder


TITLE = "Lang2SQL"
Expand Down Expand Up @@ -71,6 +72,7 @@ def execute_query(
top_n=top_n,
device=device,
use_enriched_graph=st.session_state.get("use_enriched", False),
use_simplified_graph=st.session_state.get("use_simplified", False),
session_state=st.session_state,
)

Expand Down Expand Up @@ -219,29 +221,53 @@ def should_show(_key: str) -> bool:
st.title(TITLE)

# 워크플로우 선택(UI)
st.sidebar.markdown("### 워크플로우 선택")
use_enriched = st.sidebar.checkbox(
"프로파일 추출 & 컨텍스트 보강 워크플로우 사용", value=False
)
use_simplified = st.sidebar.checkbox(
"단순화된 워크플로우 사용 (QUERY_REFINER 제거)", value=False
)

# 세션 상태 초기화
if (
"graph" not in st.session_state
or st.session_state.get("use_enriched") != use_enriched
or st.session_state.get("use_simplified") != use_simplified
):
graph_builder = enriched_builder if use_enriched else builder
st.session_state["graph"] = graph_builder.compile()
# 그래프 선택 로직
if use_simplified:
graph_builder = simplified_builder
graph_type = "단순화된"
elif use_enriched:
graph_builder = enriched_builder
graph_type = "확장된"
else:
graph_builder = builder
graph_type = "기본"

# 프로파일 추출 & 컨텍스트 보강 그래프
st.session_state["graph"] = graph_builder.compile()
st.session_state["use_enriched"] = use_enriched
st.info("Lang2SQL이 성공적으로 시작되었습니다.")
st.session_state["use_simplified"] = use_simplified
st.info(f"Lang2SQL이 성공적으로 시작되었습니다. ({graph_type} 워크플로우)")

# 새로고침 버튼 추가
if st.sidebar.button("Lang2SQL 새로고침"):
graph_builder = (
enriched_builder if st.session_state.get("use_enriched") else builder
)
# 그래프 선택 로직
if st.session_state.get("use_simplified"):
graph_builder = simplified_builder
graph_type = "단순화된"
elif st.session_state.get("use_enriched"):
graph_builder = enriched_builder
graph_type = "확장된"
else:
graph_builder = builder
graph_type = "기본"

st.session_state["graph"] = graph_builder.compile()
st.sidebar.success("Lang2SQL이 성공적으로 새로고침되었습니다.")
st.sidebar.success(
f"Lang2SQL이 성공적으로 새로고침되었습니다. ({graph_type} 워크플로우)"
)

user_query = st.text_area(
"쿼리를 입력하세요:",
Expand Down
4 changes: 4 additions & 0 deletions llm_utils/graph_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
query_refiner_with_profile_node,
context_enrichment_node,
query_maker_node_with_db_guide,
query_maker_node_without_refiner,
)

from .basic_graph import builder as basic_builder
from .enriched_graph import builder as enriched_builder
from .simplified_graph import builder as simplified_builder

__all__ = [
# 상태 및 노드 식별자
Expand All @@ -39,7 +41,9 @@
"query_refiner_with_profile_node",
"context_enrichment_node",
"query_maker_node_with_db_guide",
"query_maker_node_without_refiner",
# 그래프 빌더들
"basic_builder",
"enriched_builder",
"simplified_builder",
]
43 changes: 41 additions & 2 deletions llm_utils/graph_utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,21 @@ def context_enrichment_node(state: QueryMakerState):
searched_tables = state["searched_tables"]
searched_tables_json = json.dumps(searched_tables, ensure_ascii=False, indent=2)

question_profile = state["question_profile"].model_dump()
# question_profile이 BaseModel인 경우 model_dump() 사용, dict인 경우 그대로 사용
if hasattr(state["question_profile"], "model_dump"):
question_profile = state["question_profile"].model_dump()
else:
question_profile = state["question_profile"]
question_profile_json = json.dumps(question_profile, ensure_ascii=False, indent=2)

# refined_input이 없는 경우 초기 사용자 입력 사용
refined_question = state.get("refined_input", state["messages"][0].content)
if hasattr(refined_question, "content"):
refined_question = refined_question.content

enriched_text = query_enrichment_chain.invoke(
input={
"refined_question": state["refined_input"],
"refined_question": refined_question,
"profiles": question_profile_json,
"related_tables": searched_tables_json,
}
Expand Down Expand Up @@ -207,3 +216,33 @@ def query_maker_node_with_db_guide(state: QueryMakerState):
state["generated_query"] = res.sql
state["messages"].append(res.explanation)
return state


# 노드 함수: QUERY_MAKER 노드 (refined_input 없이)
def query_maker_node_without_refiner(state: QueryMakerState):
"""
refined_input 없이 초기 사용자 입력만을 사용하여 SQL을 생성하는 노드입니다.

이 노드는 QUERY_REFINER 단계를 건너뛰고, 초기 사용자 입력, 프로파일 정보,
컨텍스트 보강 정보를 모두 활용하여 SQL을 생성합니다.
"""
# 컨텍스트 보강된 질문 (refined_input이 없는 경우 초기 입력 사용)
enriched_question = state.get("refined_input", state["messages"][0])

# enriched_question이 AIMessage인 경우 content 추출, 문자열인 경우 그대로 사용
if hasattr(enriched_question, "content"):
enriched_question_content = enriched_question.content
else:
enriched_question_content = str(enriched_question)

res = query_maker_chain.invoke(
input={
"user_input": [state["messages"][0].content],
"refined_input": [enriched_question_content],
"searched_tables": [json.dumps(state["searched_tables"])],
"user_database_env": [state["user_database_env"]],
}
)
state["generated_query"] = res
state["messages"].append(res)
return state
38 changes: 38 additions & 0 deletions llm_utils/graph_utils/simplified_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import json

from langgraph.graph import StateGraph, END
from llm_utils.graph_utils.base import (
QueryMakerState,
GET_TABLE_INFO,
PROFILE_EXTRACTION,
CONTEXT_ENRICHMENT,
QUERY_MAKER,
get_table_info_node,
profile_extraction_node,
context_enrichment_node,
query_maker_node_without_refiner,
)

"""
QUERY_REFINER 단계를 제거한 단순화된 워크플로우입니다.
GET_TABLE_INFO → PROFILE_EXTRACTION → CONTEXT_ENRICHMENT → QUERY_MAKER 순서로 실행됩니다.
초기 사용자 입력만을 사용하여 더 정확한 쿼리를 생성합니다.
"""

# StateGraph 생성 및 구성
builder = StateGraph(QueryMakerState)
builder.set_entry_point(GET_TABLE_INFO)

# 노드 추가
builder.add_node(GET_TABLE_INFO, get_table_info_node)
builder.add_node(PROFILE_EXTRACTION, profile_extraction_node)
builder.add_node(CONTEXT_ENRICHMENT, context_enrichment_node)
builder.add_node(QUERY_MAKER, query_maker_node_without_refiner)

# 기본 엣지 설정
builder.add_edge(GET_TABLE_INFO, PROFILE_EXTRACTION)
builder.add_edge(PROFILE_EXTRACTION, CONTEXT_ENRICHMENT)
builder.add_edge(CONTEXT_ENRICHMENT, QUERY_MAKER)

# QUERY_MAKER 노드 후 종료
builder.add_edge(QUERY_MAKER, END)
18 changes: 15 additions & 3 deletions llm_utils/query_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from llm_utils.graph_utils.enriched_graph import builder as enriched_builder
from llm_utils.graph_utils.basic_graph import builder as basic_builder
from llm_utils.graph_utils.simplified_graph import builder as simplified_builder
from llm_utils.llm_response_parser import LLMResponseParser

logger = logging.getLogger(__name__)
Expand All @@ -25,6 +26,7 @@ def execute_query(
top_n: int = 5,
device: str = "cpu",
use_enriched_graph: bool = False,
use_simplified_graph: bool = False,
session_state: Optional[Union[Dict[str, Any], Any]] = None,
) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -52,19 +54,29 @@ def execute_query(
"""

logger.info("Processing query: %s", query)
logger.info("Using %s graph", "enriched" if use_enriched_graph else "basic")

# 그래프 선택
if use_simplified_graph:
graph_type = "simplified"
graph_builder = simplified_builder
elif use_enriched_graph:
graph_type = "enriched"
graph_builder = enriched_builder
else:
graph_type = "basic"
graph_builder = basic_builder

logger.info("Using %s graph", graph_type)

# 그래프 선택 및 컴파일
if session_state is not None:
# Streamlit 환경: 세션 상태에서 그래프 재사용
graph = session_state.get("graph")
if graph is None:
graph_builder = enriched_builder if use_enriched_graph else basic_builder
graph = graph_builder.compile()
session_state["graph"] = graph
else:
# CLI 환경: 매번 새로운 그래프 컴파일
graph_builder = enriched_builder if use_enriched_graph else basic_builder
graph = graph_builder.compile()

# 그래프 실행
Expand Down
14 changes: 0 additions & 14 deletions pyproject.toml

This file was deleted.