diff --git a/cli/__init__.py b/cli/__init__.py index 5670638..b492dd1 100644 --- a/cli/__init__.py +++ b/cli/__init__.py @@ -189,6 +189,11 @@ def run_streamlit_cli_command(port: int) -> None: is_flag=True, help="확장된 그래프(프로파일 추출 + 컨텍스트 보강) 사용 여부", ) +@click.option( + "--use-simplified-graph", + is_flag=True, + help="단순화된 그래프(QUERY_REFINER 제거) 사용 여부", +) def query_command( question: str, database_env: str, @@ -196,6 +201,7 @@ def query_command( top_n: int, device: str, use_enriched_graph: bool, + use_simplified_graph: bool, ) -> None: """ 자연어 질문을 SQL 쿼리로 변환하여 출력하는 명령어입니다. @@ -227,6 +233,7 @@ def query_command( top_n=top_n, device=device, use_enriched_graph=use_enriched_graph, + use_simplified_graph=use_simplified_graph, ) # SQL 추출 및 출력 diff --git a/interface/lang2sql.py b/interface/lang2sql.py index 8293141..9aaa47b 100644 --- a/interface/lang2sql.py +++ b/interface/lang2sql.py @@ -20,6 +20,7 @@ from llm_utils.token_utils import TokenUtils from llm_utils.graph_utils.enriched_graph import builder as enriched_builder from llm_utils.graph_utils.basic_graph import builder +from llm_utils.graph_utils.simplified_graph import builder as simplified_builder TITLE = "Lang2SQL" @@ -71,6 +72,7 @@ def execute_query( top_n=top_n, device=device, use_enriched_graph=st.session_state.get("use_enriched", False), + use_simplified_graph=st.session_state.get("use_simplified", False), session_state=st.session_state, ) @@ -219,29 +221,53 @@ def should_show(_key: str) -> bool: st.title(TITLE) # 워크플로우 선택(UI) +st.sidebar.markdown("### 워크플로우 선택") use_enriched = st.sidebar.checkbox( "프로파일 추출 & 컨텍스트 보강 워크플로우 사용", value=False ) +use_simplified = st.sidebar.checkbox( + "단순화된 워크플로우 사용 (QUERY_REFINER 제거)", value=False +) # 세션 상태 초기화 if ( "graph" not in st.session_state or st.session_state.get("use_enriched") != use_enriched + or st.session_state.get("use_simplified") != use_simplified ): - graph_builder = enriched_builder if use_enriched else builder - st.session_state["graph"] = graph_builder.compile() + # 그래프 선택 로직 + if use_simplified: + graph_builder = simplified_builder + graph_type = "단순화된" + elif use_enriched: + graph_builder = enriched_builder + graph_type = "확장된" + else: + graph_builder = builder + graph_type = "기본" - # 프로파일 추출 & 컨텍스트 보강 그래프 + st.session_state["graph"] = graph_builder.compile() st.session_state["use_enriched"] = use_enriched - st.info("Lang2SQL이 성공적으로 시작되었습니다.") + st.session_state["use_simplified"] = use_simplified + st.info(f"Lang2SQL이 성공적으로 시작되었습니다. ({graph_type} 워크플로우)") # 새로고침 버튼 추가 if st.sidebar.button("Lang2SQL 새로고침"): - graph_builder = ( - enriched_builder if st.session_state.get("use_enriched") else builder - ) + # 그래프 선택 로직 + if st.session_state.get("use_simplified"): + graph_builder = simplified_builder + graph_type = "단순화된" + elif st.session_state.get("use_enriched"): + graph_builder = enriched_builder + graph_type = "확장된" + else: + graph_builder = builder + graph_type = "기본" + st.session_state["graph"] = graph_builder.compile() - st.sidebar.success("Lang2SQL이 성공적으로 새로고침되었습니다.") + st.sidebar.success( + f"Lang2SQL이 성공적으로 새로고침되었습니다. ({graph_type} 워크플로우)" + ) user_query = st.text_area( "쿼리를 입력하세요:", diff --git a/llm_utils/graph_utils/__init__.py b/llm_utils/graph_utils/__init__.py index 41b539b..794578d 100644 --- a/llm_utils/graph_utils/__init__.py +++ b/llm_utils/graph_utils/__init__.py @@ -18,10 +18,12 @@ query_refiner_with_profile_node, context_enrichment_node, query_maker_node_with_db_guide, + query_maker_node_without_refiner, ) from .basic_graph import builder as basic_builder from .enriched_graph import builder as enriched_builder +from .simplified_graph import builder as simplified_builder __all__ = [ # 상태 및 노드 식별자 @@ -39,7 +41,9 @@ "query_refiner_with_profile_node", "context_enrichment_node", "query_maker_node_with_db_guide", + "query_maker_node_without_refiner", # 그래프 빌더들 "basic_builder", "enriched_builder", + "simplified_builder", ] diff --git a/llm_utils/graph_utils/base.py b/llm_utils/graph_utils/base.py index d725651..c00c490 100644 --- a/llm_utils/graph_utils/base.py +++ b/llm_utils/graph_utils/base.py @@ -140,12 +140,21 @@ def context_enrichment_node(state: QueryMakerState): searched_tables = state["searched_tables"] searched_tables_json = json.dumps(searched_tables, ensure_ascii=False, indent=2) - question_profile = state["question_profile"].model_dump() + # question_profile이 BaseModel인 경우 model_dump() 사용, dict인 경우 그대로 사용 + if hasattr(state["question_profile"], "model_dump"): + question_profile = state["question_profile"].model_dump() + else: + question_profile = state["question_profile"] question_profile_json = json.dumps(question_profile, ensure_ascii=False, indent=2) + # refined_input이 없는 경우 초기 사용자 입력 사용 + refined_question = state.get("refined_input", state["messages"][0].content) + if hasattr(refined_question, "content"): + refined_question = refined_question.content + enriched_text = query_enrichment_chain.invoke( input={ - "refined_question": state["refined_input"], + "refined_question": refined_question, "profiles": question_profile_json, "related_tables": searched_tables_json, } @@ -207,3 +216,33 @@ def query_maker_node_with_db_guide(state: QueryMakerState): state["generated_query"] = res.sql state["messages"].append(res.explanation) return state + + +# 노드 함수: QUERY_MAKER 노드 (refined_input 없이) +def query_maker_node_without_refiner(state: QueryMakerState): + """ + refined_input 없이 초기 사용자 입력만을 사용하여 SQL을 생성하는 노드입니다. + + 이 노드는 QUERY_REFINER 단계를 건너뛰고, 초기 사용자 입력, 프로파일 정보, + 컨텍스트 보강 정보를 모두 활용하여 SQL을 생성합니다. + """ + # 컨텍스트 보강된 질문 (refined_input이 없는 경우 초기 입력 사용) + enriched_question = state.get("refined_input", state["messages"][0]) + + # enriched_question이 AIMessage인 경우 content 추출, 문자열인 경우 그대로 사용 + if hasattr(enriched_question, "content"): + enriched_question_content = enriched_question.content + else: + enriched_question_content = str(enriched_question) + + res = query_maker_chain.invoke( + input={ + "user_input": [state["messages"][0].content], + "refined_input": [enriched_question_content], + "searched_tables": [json.dumps(state["searched_tables"])], + "user_database_env": [state["user_database_env"]], + } + ) + state["generated_query"] = res + state["messages"].append(res) + return state diff --git a/llm_utils/graph_utils/simplified_graph.py b/llm_utils/graph_utils/simplified_graph.py new file mode 100644 index 0000000..a241537 --- /dev/null +++ b/llm_utils/graph_utils/simplified_graph.py @@ -0,0 +1,38 @@ +import json + +from langgraph.graph import StateGraph, END +from llm_utils.graph_utils.base import ( + QueryMakerState, + GET_TABLE_INFO, + PROFILE_EXTRACTION, + CONTEXT_ENRICHMENT, + QUERY_MAKER, + get_table_info_node, + profile_extraction_node, + context_enrichment_node, + query_maker_node_without_refiner, +) + +""" +QUERY_REFINER 단계를 제거한 단순화된 워크플로우입니다. +GET_TABLE_INFO → PROFILE_EXTRACTION → CONTEXT_ENRICHMENT → QUERY_MAKER 순서로 실행됩니다. +초기 사용자 입력만을 사용하여 더 정확한 쿼리를 생성합니다. +""" + +# StateGraph 생성 및 구성 +builder = StateGraph(QueryMakerState) +builder.set_entry_point(GET_TABLE_INFO) + +# 노드 추가 +builder.add_node(GET_TABLE_INFO, get_table_info_node) +builder.add_node(PROFILE_EXTRACTION, profile_extraction_node) +builder.add_node(CONTEXT_ENRICHMENT, context_enrichment_node) +builder.add_node(QUERY_MAKER, query_maker_node_without_refiner) + +# 기본 엣지 설정 +builder.add_edge(GET_TABLE_INFO, PROFILE_EXTRACTION) +builder.add_edge(PROFILE_EXTRACTION, CONTEXT_ENRICHMENT) +builder.add_edge(CONTEXT_ENRICHMENT, QUERY_MAKER) + +# QUERY_MAKER 노드 후 종료 +builder.add_edge(QUERY_MAKER, END) diff --git a/llm_utils/query_executor.py b/llm_utils/query_executor.py index 7a87e3c..68b16bc 100644 --- a/llm_utils/query_executor.py +++ b/llm_utils/query_executor.py @@ -12,6 +12,7 @@ from llm_utils.graph_utils.enriched_graph import builder as enriched_builder from llm_utils.graph_utils.basic_graph import builder as basic_builder +from llm_utils.graph_utils.simplified_graph import builder as simplified_builder from llm_utils.llm_response_parser import LLMResponseParser logger = logging.getLogger(__name__) @@ -25,6 +26,7 @@ def execute_query( top_n: int = 5, device: str = "cpu", use_enriched_graph: bool = False, + use_simplified_graph: bool = False, session_state: Optional[Union[Dict[str, Any], Any]] = None, ) -> Dict[str, Any]: """ @@ -52,19 +54,29 @@ def execute_query( """ logger.info("Processing query: %s", query) - logger.info("Using %s graph", "enriched" if use_enriched_graph else "basic") + + # 그래프 선택 + if use_simplified_graph: + graph_type = "simplified" + graph_builder = simplified_builder + elif use_enriched_graph: + graph_type = "enriched" + graph_builder = enriched_builder + else: + graph_type = "basic" + graph_builder = basic_builder + + logger.info("Using %s graph", graph_type) # 그래프 선택 및 컴파일 if session_state is not None: # Streamlit 환경: 세션 상태에서 그래프 재사용 graph = session_state.get("graph") if graph is None: - graph_builder = enriched_builder if use_enriched_graph else basic_builder graph = graph_builder.compile() session_state["graph"] = graph else: # CLI 환경: 매번 새로운 그래프 컴파일 - graph_builder = enriched_builder if use_enriched_graph else basic_builder graph = graph_builder.compile() # 그래프 실행 diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index f72e3f5..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,14 +0,0 @@ -[tool.black] -line-length = 88 -target-version = ['py311'] -include = '\.pyi?$' -exclude = ''' -( - /( - \.git - | \.venv - | build - | dist - )/ -) -'''